# Copyright 2024 Mews
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import math
import altair as alt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
from loguru import logger
from plotly.subplots import make_subplots
import diva.tools as tools
from diva import parameters
from diva.graphs.graph_generator import GraphTexts, IGraphGenerator
from diva.logging.logger import crossing_point, Process
[docs]
class ServiceGeneratePlotlyGraph(IGraphGenerator):
[docs]
@crossing_point("""Generate the graphic : This method calls other methods to build all the components of the graph""")
def generate(self, show=False):
    """
    Validate, build, lay out and render the graph described by this instance.

    Pipeline:
    1. ``checks`` validates the parameters (it emits Streamlit warnings itself);
       when it reports an error nothing else is done.
    2. An empty Plotly figure is created as the default canvas.
    3. ``matching_case`` dispatches to the right ``generate_*`` method.
    4. The shared layout is applied, the figure is rendered, and the final
       context hook is called.

    Parameters
    ----------
    show : bool, optional
        Forwarded to the rendering step; when True the figure is also shown
        with ``fig.show()`` (test mode). Default is False.

    Returns
    -------
    None
    """
    has_errors = self.checks()
    if has_errors:
        return None
    # Default canvas: some generators draw on it, others replace it.
    self.fig = go.Figure()
    pipeline = (
        self.matching_case,
        self.__layout,
        lambda: self.__plot(show=show),
        self.__get_context,
    )
    for step in pipeline:
        step()
# ---------------------------------------------------------------------------------
# --------------------------------- graphs ----------------------------------------
[docs]
@crossing_point("""Generate the raw line chart""")
def generate_lineplot(self, comparison_type="single"):
    """
    Generate a line chart of the raw (non-aggregated) values of each collection.

    Every collection is split by its tense column into a "past" trace (solid
    line) and a "projection" trace (thin dotted line). The resulting figure is
    stored in ``self.fig``.

    Parameters
    ----------
    comparison_type : str, optional
        One of "single", "location", "time" or "location_time". When it
        contains "time", each collection after the first is attached to its
        own overlaid x-axis (tick labels hidden) so that different periods are
        drawn over the same horizontal span.

    Returns
    -------
    None
    """
    df = self.data_collection.get_values()
    df = self.add_tense_data(df)
    compare_time = "time" in comparison_type

    layout = go.Layout(
        template="plotly_dark",
        xaxis=go.layout.XAxis(tickformat="%B", rangeslider=dict(visible=True)),
    )
    if compare_time:
        for i in range(1, self.NB_PARAM_ELM):
            layout[f"xaxis{i + 1}"] = go.layout.XAxis(
                overlaying="x", side="top", showticklabels=False
            )
    fig = go.Figure(layout=layout)

    line_styles = {
        "past": {"dash": "solid", "width": None},
        "projection": {"dash": "dot", "width": 1},
    }
    for i in range(self.NB_PARAM_ELM):
        time_col = f"time_collection_{i + 1}"
        value_col = f"values_collection_{i + 1}"
        tense_col = f"tense_collection_{i + 1}"
        # Drop incomplete rows on the three columns together. Dropping NaNs
        # column by column would give x and y different lengths and pair time
        # stamps with the wrong values.
        points = df[[time_col, value_col, tense_col]].dropna()
        for tense in ["past", "projection"]:
            mask = points[tense_col] == tense
            if not mask.any():
                continue
            fig.add_trace(
                go.Scatter(
                    x=points.loc[mask, time_col],
                    y=points.loc[mask, value_col],
                    mode="lines",
                    line=dict(
                        color=self.colors[i],
                        dash=line_styles[tense]["dash"],
                        width=line_styles[tense]["width"],
                    ),
                    name=f"{self.__get_chart_name(type_='line chart', idx=i)['raw']} ({tense.capitalize()})",
                    xaxis=f"x{i + 1}" if compare_time and i > 0 else "x",
                )
            )
    self.fig = fig
[docs]
@crossing_point("""Generate the line chart with average""")
def generate_lineplot_mean(self, comparison_type="single"):
    """
    Display the aggregated line chart, with a min/max corridor when available.

    For precipitation (``var_name == 'tp'``) the totals from ``get_sum`` are
    drawn and no min/max corridor exists. Otherwise the mean line is drawn
    together with a shaded band between the min and max aggregates. Each
    collection is split into a "past" part (solid line) and a "projection"
    part (dashed line). The figure is stored in ``self.fig``.

    Parameters
    ----------
    comparison_type : str, optional
        - "single": 1 location, 1 period
        - "location": comparison of locations
        - "time": comparison of periods (each extra period gets its own
          overlaid x-axis)
        - "location_time": comparison of both location and period

    Returns
    -------
    None
    """
    if self.var_name == 'tp':
        df_mean = self.data_collection.get_sum(aggreg_type=self.aggreg_type)
        cols_val = [col for col in df_mean.columns if "values" in col]
        # Summed precipitation has no min/max corridor: keep the frame layout
        # but blank out every value.
        df_min, df_max = df_mean.copy(), df_mean.copy()
        df_min.loc[:, cols_val], df_max.loc[:, cols_val] = np.nan, np.nan
    else:
        df_mean = self.data_collection.get_mean(aggreg_type=self.aggreg_type)
        df_min = self.data_collection.get_min(aggreg_type=self.aggreg_type)
        df_max = self.data_collection.get_max(aggreg_type=self.aggreg_type)
    df_mean = self.add_tense_data(df_mean)
    n = self.NB_PARAM_ELM
    x_list = [df_mean[f'time_collection_{k + 1}'].dropna() for k in range(n)]
    y_mean_list = [df_mean[f'values_collection_{k + 1}'].dropna() for k in range(n)]
    y_min_list = [df_min[f'values_collection_{k + 1}'].dropna() for k in range(n)]
    y_max_list = [df_max[f'values_collection_{k + 1}'].dropna() for k in range(n)]
    tense_list = [df_mean[f"tense_collection_{k + 1}"].dropna() for k in range(n)]
    compare_time = "time" in comparison_type

    # Layout for multiple x-axes if comparing times
    layout = go.Layout(
        xaxis=go.layout.XAxis(tickformat="%B"), template="plotly_dark"
    )
    if compare_time:
        for i in range(1, n):
            layout[f"xaxis{i + 1}"] = go.layout.XAxis(
                overlaying="x", side="top", showticklabels=False
            )
    fig = go.Figure(layout=layout)

    dashes = {"past": None, "projection": "2px,6px"}
    for i in range(len(x_list)):
        tenses = tense_list[i]
        xaxis = f"x{i + 1}" if compare_time and i > 0 else "x"
        for tense in ["past", "projection"]:
            mask = tenses == tense
            if not mask.any():
                continue
            if tense == "past":
                # Add the first projection point to the past segment so the
                # solid and dashed lines join without a gap. The point is found
                # by position rather than by label, so this does not depend on
                # the index being a 0-based RangeIndex.
                projection_pos = np.flatnonzero((tenses == "projection").to_numpy())
                if projection_pos.size:
                    mask.iloc[projection_pos[0]] = True
            suffix = f"({tense.capitalize()})"
            chart_names = self.__get_chart_name(type_="line chart", idx=i)
            # Mean (or total) line
            fig.add_trace(
                go.Scatter(
                    x=x_list[i][mask],
                    y=y_mean_list[i][mask],
                    mode="lines",
                    name=f"{chart_names['mean']} {suffix}",
                    line=dict(color=self.colors[i], dash=dashes[tense]),
                    xaxis=xaxis,
                )
            )
            y_max = y_max_list[i][mask]
            y_min = y_min_list[i][mask]
            if y_max.empty or y_min.empty:
                # No corridor data (e.g. precipitation totals): skip the band
                # rather than adding empty traces with a blank legend entry.
                continue
            # Invisible upper edge of the corridor (hidden from the legend).
            fig.add_trace(
                go.Scatter(
                    x=x_list[i][mask],
                    y=y_max,
                    name=f"Max {suffix}",
                    mode="lines",
                    marker=dict(color="#444"),
                    line=dict(width=0, color="white"),
                    showlegend=False,
                    xaxis=xaxis,
                )
            )
            # Lower edge: fill="tonexty" shades the area up to the previous
            # trace, i.e. the upper edge added just above.
            fig.add_trace(
                go.Scatter(
                    x=x_list[i][mask],
                    y=y_min,
                    name=f"{chart_names['min_max']} {suffix}",
                    marker=dict(color="#444"),
                    mode="lines",
                    line=dict(width=0, color="white"),
                    fillcolor="rgba(211, 211, 211, 0.5)",
                    fill="tonexty",
                    showlegend=True,
                    xaxis=xaxis,
                )
            )
    self.fig = fig
[docs]
@crossing_point("""Generate the line chart with monthly average for each year""")
def generate_rainbow(self, comparison_type="single"):
    """
    Draw a "rainbow" chart: one line per year of the temporally aggregated data.

    The data is aggregated with ``self.aggreg_type``. A "single" comparison is
    drawn from the first dataset only, on one figure. Any other comparison
    type gets one subplot per dataset.

    Parameters
    ----------
    comparison_type : str, optional
        "single", "location", "time" or "location_time".
    """
    datasets = self.data_collection.get_temporal_agreg_data(t_agreg=self.aggreg_type)
    if comparison_type != 'single':
        self.__make_subplots_rainbow(comparison=comparison_type, data=datasets)
        return
    self.__make_rainbow(datasets[0])
[docs]
@crossing_point("""Generate the bar chart""")
def generate_barplot(self, comparison_type="single"):
    """
    Display the bar chart of aggregated data (sums for precipitation, means otherwise).

    Each collection is split into a "past" series (plain bars) and a
    "projection" series (hatched, slightly transparent bars). The figure is
    stored in ``self.fig``.

    Parameters
    ----------
    comparison_type : str, optional
        - "single": 1 location, 1 period
        - "location": comparison of locations
        - "time": comparison of periods (each extra period gets its own
          overlaid x-axis)
        - "location_time": comparison of both location and period

    Returns
    -------
    None
    """
    if self.var_name == 'tp':
        df_mean = self.data_collection.get_sum(aggreg_type=self.aggreg_type)
    else:
        df_mean = self.data_collection.get_mean(aggreg_type=self.aggreg_type)
    df_mean = self.add_tense_data(df_mean)
    type_ = "bar chart" if comparison_type == 'single' else "histogram_&bar_&box"
    compare_time = "time" in comparison_type

    # Layout for multiple x-axes if comparing times
    layout = go.Layout(
        xaxis=go.layout.XAxis(tickformat="%B"), template="plotly_dark"
    )
    if compare_time:
        for i in range(1, self.NB_PARAM_ELM):
            layout[f"xaxis{i + 1}"] = go.layout.XAxis(
                overlaying="x", side="top", showticklabels=False
            )
    bar_styles = {
        "past": {"pattern_shape": "", "opacity": 1.0},
        "projection": {"pattern_shape": "/", "opacity": 0.8},
    }
    fig = go.Figure(layout=layout)
    for i in range(self.NB_PARAM_ELM):
        time_col = f"time_collection_{i + 1}"
        value_col = f"values_collection_{i + 1}"
        tense_col = f"tense_collection_{i + 1}"
        for tense in ["past", "projection"]:
            mask = df_mean[tense_col] == tense
            if not mask.any():
                continue
            # Drop rows with a missing time OR value together, so that each
            # bar keeps its own value (separate dropna calls could shift them).
            bars = df_mean.loc[mask, [time_col, value_col]].dropna()
            fig.add_trace(
                go.Bar(
                    x=bars[time_col],
                    y=bars[value_col],
                    name=f"{self.__get_chart_name(type_=type_, idx=i)} ({tense.capitalize()})",
                    marker=dict(
                        color=self.colors[i],
                        pattern=dict(shape=bar_styles[tense]["pattern_shape"],
                                     fgcolor="white"),
                    ),
                    opacity=bar_styles[tense]["opacity"],
                    xaxis=f"x{i + 1}" if compare_time and i > 0 else "x",
                )
            )
    self.fig = fig
[docs]
@crossing_point("""Generate the distribution""")
def generate_distribution(self, comparison_type="single"):
    """
    Add one histogram trace per collection to ``self.fig``.

    Past and projection values are pooled: the histogram ignores the tense
    split. Unlike most generators, this method draws on the existing
    ``self.fig`` (created in ``generate``) instead of building a new figure.

    Parameters
    ----------
    comparison_type : str, optional
        "single" names the trace with the "bar chart" label; any other value
        uses the per-location / per-period label from ``__get_chart_name``.
    """
    df = self.data_collection.get_values()
    # NOTE(review): the tense columns are not used by the histogram; the call
    # is kept only in case add_tense_data has side effects. Drop it if it is pure.
    df = self.add_tense_data(df)
    type_ = "bar chart" if comparison_type == 'single' else "histogram_&bar_&box"
    for i in range(self.NB_PARAM_ELM):
        self.fig.add_trace(
            go.Histogram(
                x=df[f"values_collection_{i + 1}"].dropna(),
                marker_color=self.colors[i],
                name=self.__get_chart_name(type_=type_, idx=i),
            )
        )
[docs]
@crossing_point("""Generate the warming stripes""")
def generate_warming_stripes(self, comparison_type="single"):
    """
    Draw warming stripes: aggregated values shown relative to the reference period.

    Values are aggregated (sum for precipitation, mean otherwise). They are
    then turned into differences from the value(s) returned by
    ``get_mean_val_on_ref_period`` via ``compute_diff_ref``.

    Parameters
    ----------
    comparison_type : str, optional
        "single" draws one strip on a single figure; any other value draws
        one strip per collection in a subplot grid.
    """
    aggregate = (
        self.data_collection.get_sum
        if self.var_name == 'tp'
        else self.data_collection.get_mean
    )
    df_mean = self.add_tense_data(aggregate(aggreg_type=self.aggreg_type))

    if comparison_type != "single":
        n = self.NB_PARAM_ELM
        times = [df_mean[f"time_collection_{k + 1}"].dropna() for k in range(n)]
        values = [df_mean[f"values_collection_{k + 1}"].dropna() for k in range(n)]
        references = self.get_mean_val_on_ref_period()
        if comparison_type == "time":
            # One reference per period. This presumably relies on
            # get_mean_val_on_ref_period returning a one-element list here
            # (TODO confirm).
            references = references * len(times)
        anomalies = self.compute_diff_ref(times, values, ref=references, mode='multi')
        self.__make_subplots_warming_stripes(
            comparison=comparison_type, X=times, Y=anomalies)
        return

    time_axis = df_mean['time_collection_1']
    raw_values = df_mean['values_collection_1']
    reference = self.get_mean_val_on_ref_period()[0]
    anomaly = self.compute_diff_ref(time_axis, raw_values, ref=reference, mode='simple')
    self.__make_warming_stripes(time_axis, [anomaly])
[docs]
@crossing_point("""Generate the box plot""")
def generate_boxplot(self, comparison_type="single"):
    """
    Add one box-plot trace per collection, with one box per aggregation period.

    Each group of values from ``get_groupe_values`` is labelled with its
    period, formatted with the pattern from ``__get_strftime_from_agg``. For
    "time" and "location_time" comparisons, ``self.fig`` is replaced by a dark
    figure with overlaid extra x-axes. Otherwise the figure created in
    ``generate`` is reused.

    Parameters
    ----------
    comparison_type : str, optional
        "single", "location", "time" or "location_time".
    """
    df_grps = self.data_collection.get_groupe_values(aggreg_type=self.aggreg_type)
    n = self.NB_PARAM_ELM
    periods = [df_grps[f"time_collection_{k + 1}"].dropna() for k in range(n)]
    groups = [df_grps[f"values_collection_{k + 1}"].dropna() for k in range(n)]
    type_ = "bar chart" if comparison_type == 'single' else "histogram_&bar_&box"

    if comparison_type in ['time', 'location_time']:
        layout = go.Layout(
            xaxis=go.layout.XAxis(tickformat="%B"), template="plotly_dark"
        )
        for k in range(1, n):
            layout[f"xaxis{k + 1}"] = go.layout.XAxis(
                overlaying="x", side="top", showticklabels=False
            )
        self.fig = go.Figure(layout=layout)

    date_format = self.__get_strftime_from_agg()
    for j, (period_col, value_col) in enumerate(zip(periods, groups)):
        box_x, box_y = [], []
        for stamp, values in zip(period_col, value_col):
            # Repeat the period label once per value so Plotly groups them in one box.
            label = pd.to_datetime(stamp).strftime(date_format)
            box_x.extend([label] * len(values))
            box_y.extend(values)
        self.fig.add_trace(
            go.Box(
                x=box_x,
                y=box_y,
                name=self.__get_chart_name(type_=type_, idx=j),
                line=dict(color=self.colors[j]),
            )
        )
[docs]
@crossing_point("""Generate the heatmap""")
def generate_map(self, comparison_type=None):
    """
    Generate an Altair map of the spatially aggregated data.

    One point map is drawn per dataset returned by ``get_spacial_agreg_data``.
    Points are coloured by ``self.var_name`` with the "plasma" scheme and
    drawn in the equalEarth projection. Several maps are concatenated
    horizontally. At most 5000 randomly sampled points are drawn per map. The
    resulting chart is stored in ``self.fig``.

    Parameters
    ----------
    comparison_type : optional
        Unused; accepted so every generator has the same signature.

    Raises
    ------
    ValueError
        If no spatial dataset is available.
    """
    data_list = self.data_collection.get_spacial_agreg_data()
    if len(data_list) == 0:
        # Without this guard the final assignment fails on an unbound chart.
        raise ValueError("No spatial data available to build the heatmap.")
    chart_all = None
    for i, data in enumerate(data_list):
        df = data.to_dataframe().reset_index()
        df = df[["latitude", "longitude", self.var_name]].dropna()
        if len(df) > 5000:
            # Cap the number of plotted points to keep the chart light.
            df = df.sample(5000)
        chart = (
            alt.Chart(df)
            .mark_circle(size=100)
            .encode(
                longitude="longitude:Q",
                latitude="latitude:Q",
                color=alt.Color(
                    f"{self.var_name}:Q",
                    scale=alt.Scale(scheme="plasma"),
                    title=f"{self.climate_variable.capitalize()} ({self.unit})",
                    # Legend and its title on the right-hand side of the map.
                    legend=alt.Legend(titleOrient="right", orient="right"),
                ),
                tooltip=["latitude", "longitude", self.var_name],
            )
            .properties(
                title=tools.translate_from_en(
                    self.__get_chart_name(type_="heatmap", idx=i), tools.langage_to_iso(self.langage)
                )
            )
        )
        if i == 0:
            chart_all = chart.project('equalEarth')
        else:
            # chart_all is always set here: the first iteration assigns it.
            chart_all = alt.hconcat(chart_all, chart.project('equalEarth', precision=.707))
    self.fig = chart_all
# ---------------------------------------------------------------------------------
# --------------------------------- Graph customization ---------------------------
[docs]
def get_data_source(self):
    """Return the display label of the data source(s) behind the graph (public wrapper)."""
    label = self.__get_data_source()
    return label
def __get_data_source(self) -> str:
    """
    Build a display label listing the data sources of all collections.

    A dataset's source id is the part of its file name before the first
    underscore. The ids are de-duplicated and sorted, then joined with
    ' & '. Finally 'era5' is rewritten as 'ERA5' and 'dt' as 'Climate DT'.
    Sorting keeps the label stable between runs: iteration order of a set
    of strings depends on per-process hash randomization.

    Returns
    -------
    str
        e.g. ``"Climate DT & ERA5"``.
    """
    sources = sorted({
        ds.file.split('_')[0]
        for col in self.data_collection.consistent_collections
        for ds in col
    })
    data_source = ' & '.join(sources).replace('era5', 'ERA5').replace('dt', 'Climate DT')
    return data_source
def __layout(self):
    """
    Apply the common Plotly layout (title, axis titles, tip note) to ``self.fig``.

    Title and axes:
    - The title combines the climate variable with the period sentence from
      ``__display_periode_for_title``.
    - Histograms put the variable on the x axis and "Number of observations"
      on the y axis. Every other graph type uses "Period" on x and the
      variable on y.
    - All texts are translated through ``GraphTexts.translate``. The x-axis
      title gets a smaller sub-line naming the data source(s) and the variable.

    Subplots and exclusions:
    - For subplot figures the axis titles are applied cell by cell.
    - Warming-stripe subplots get the colour scale from
      ``__get_colors_in_subplots``.
    - Heatmaps are skipped because they are Altair charts, not Plotly figures.

    Returns
    -------
    None
    """
    graph_type_excluded = ["heatmap"]
    if self.graph_type in graph_type_excluded:
        return None
    sentence_period = self.__display_periode_for_title()
    title = f"{self.climate_variable} {sentence_period}"
    if self.graph_type == "histogram":
        xaxis_title = f"{self.climate_variable.capitalize()} ({self.unit})"
        yaxis_title = "Number of observations"
    else:
        xaxis_title = "Period"
        yaxis_title = f"{self.climate_variable.capitalize()} ({self.unit})"
    under_xaxis_title1 = "Source"
    under_xaxis_title2 = "Used variable"
    tips = "Tips: see 'info' tab for other graph types"
    graph_txt = GraphTexts(
        self.graph_type,
        title,
        xaxis_title,
        under_xaxis_title1,
        under_xaxis_title2,
        yaxis_title,
        tips,
        self.langage,
    )
    graph_txt = graph_txt.translate()
    # Computed once: it scans every dataset of every collection.
    data_source = self.__get_data_source()
    # NOTE(review): __get_data_source() renders the 'dt' source as
    # 'Climate DT', so "Digital Twin" looks like it can never appear here and
    # the else branch (which also leaves <sup> unclosed) seems unreachable.
    # Confirm the intended label before changing this check.
    if "Digital Twin" not in data_source:
        xaxis_title = (
            f"{graph_txt.x_axis_title} <br><sup>{graph_txt.under_x_axis_title1} : {data_source}"
            f" | {graph_txt.under_x_axis_title2} : {self.var_name}</sup>"
        )
    else:
        xaxis_title = (
            f"{graph_txt.x_axis_title} <br><sup>{graph_txt.under_x_axis_title1} : {data_source}"
        )
    self.fig.update_layout(template="plotly_dark")
    self.fig.update_layout(
        title=graph_txt.graph_type + ' ' + graph_txt.title,
        xaxis_title=xaxis_title,
        yaxis_title=graph_txt.y_axis_title,
    )
    # Tip annotation, right-anchored, at the position chosen for this graph type.
    x_pos, y_pos = self.get_xy_pos_note_under_graph()
    self.fig.add_annotation(
        text=f"<i>💡{graph_txt.tips}</i>",
        xref="paper",
        yref="paper",
        x=x_pos,
        y=y_pos,
        showarrow=False,
        font=dict(size=12, color="gray"),
        xanchor="right",
    )
    if self.is_subplot:
        NB_FIG = len(self.params_after_processing)
        # Same near-square grid as the one built by __make_subplots_fig.
        cols = math.ceil(math.sqrt(NB_FIG))
        rows = math.ceil(NB_FIG / cols)
        for i in range(1, len(self.fig.data) + 1):
            row = (i - 1) // cols + 1
            col = (i - 1) % cols + 1
            # Show the x title on the bottom row, on every cell when there are
            # 2 figures, or on the last column when the count is a multiple of 3.
            if (row == rows) or NB_FIG == 2 or (NB_FIG % 3 == 0 and col == cols):
                xaxis_title_ = xaxis_title
            else:
                xaxis_title_ = ""
            self.fig.update_xaxes(title_text=xaxis_title_, row=row, col=col)
            self.fig.update_yaxes(
                title_text=graph_txt.y_axis_title, row=row, col=col
            )
        if self.graph_type in ["warming stripes"]:
            self.fig.update_layout(
                coloraxis={"colorscale": self.__get_colors_in_subplots()}
            )
def __plot(self, show=False):
    """
    Save the current figure in the session history and render it in Streamlit.

    The figure is stored in ``st.session_state.plots_history`` under the key
    ``"msg_<index of the last message>"``. It is then drawn at container
    width with ``st.altair_chart`` for heatmaps (Altair charts) or with
    ``st.plotly_chart`` otherwise.

    Parameters
    ----------
    show : bool, optional
        If True, also call ``self.fig.show()``. Defaults to False.

    Returns
    -------
    None
    """
    last_message_idx = len(st.session_state.messages) - 1
    st.session_state.plots_history["msg_{}".format(last_message_idx)] = self.fig
    if self.graph_type == "heatmap":
        draw = st.altair_chart
    else:
        draw = st.plotly_chart
    draw(self.fig, use_container_width=True)
    if show:
        self.fig.show()
def __get_context(self):
    """
    Final hook called by ``generate`` after the figure has been rendered.

    NOTE(review): the body here is only ``...``, so this currently does
    nothing. Confirm whether an implementation is still to come.
    """
    ...
[docs]
@crossing_point("""This function finds the matching figure associated to the input parameters""")
def matching_case(self):
    """
    Dispatch to the ``generate_*`` method that matches ``self.graph_type``.

    Supported graph types: "line chart", "bar chart", "histogram",
    "warming stripes", "box plot" and "heatmap". A "line chart" is further
    specialised:
    - a rainbow chart when ``self.delta_time`` reaches
      ``parameters.threshold_rainbow``;
    - otherwise the raw line chart when ``self.aggreg_type == "raw"``;
    - otherwise the aggregated (mean) line chart.

    The selected method receives ``comparison_type=self.comparison_type``.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If ``self.graph_type`` is not one of the supported graph types.
    """
    function_mapping = {
        "line chart": {
            "raw": self.generate_lineplot,
            "mean": self.generate_lineplot_mean,
            "mean_rainbow": self.generate_rainbow,
        },
        "bar chart": self.generate_barplot,
        "histogram": self.generate_distribution,
        "warming stripes": self.generate_warming_stripes,
        "box plot": self.generate_boxplot,
        "heatmap": self.generate_map,
    }
    plot_function = function_mapping.get(self.graph_type)
    if plot_function is None:
        # Fail with an explicit message instead of "'NoneType' object is not callable".
        raise ValueError(
            f"Unsupported graph type: {self.graph_type!r}. "
            f"Expected one of: {', '.join(function_mapping)}."
        )
    if isinstance(plot_function, dict):
        if self.delta_time >= parameters.threshold_rainbow:
            plot_function = plot_function["mean_rainbow"]
        elif self.aggreg_type == "raw":
            plot_function = plot_function["raw"]
        else:
            plot_function = plot_function["mean"]
    plot_function(comparison_type=self.comparison_type)
def __display_periode_for_title(self):
    """
    Return the period fragment of the graph title.

    "from <start> to <end>" when ``params["starttime"]`` is a single string,
    otherwise "for several periods".
    """
    start = self.params["starttime"]
    if not isinstance(start, str):
        return "for several periods"
    return f"from {start} to {self.params['endtime']}"
# ---------------------------------------------------------------------------------
# --------------------------------- Tools -----------------------------------------
[docs]
@crossing_point(
    """This function is checking if no rule has been broken (period limits, location limits). If something is not right, it raises warning""")
def checks(self):
    """
    Validate the request parameters before a graph is generated.

    Every rule is evaluated (no early exit), and each broken rule shows a
    Streamlit warning, so the user sees all problems at once:

    - when comparing both locations and periods, there must be as many
      periods as locations;
    - the location(s) must be within bounds (``self.out_of_bounds``);
    - the period(s) must be within bounds (``self.time_out_of_bounds``);
    - the heatmap must be supported (``self.heatmap_not_supported``).

    Returns
    -------
    bool
        True if at least one rule is broken (the graph must not be
        generated); otherwise False.
    """
    error = False
    # ---- handle bounds and other errors ----
    if self.comparison_time and self.comparison_loc:
        if len(self.starttime) != len(self.location):
            msg = "The number of locations does not correspond to the number of periods. The graph cannot be generated. This warning is printed because you are comparing several locations for several periods of time."
            st.warning(msg)
            error = True
    if self.out_of_bounds:
        if self.comparison_loc:
            msg = f"Out of bounds: The following location(s) you provided is(are) not found in the database : {self.list_not_found_loc}"
        else:
            msg = f"Out of bounds: The location you provided ({self.location['location_name'].capitalize()}) is out of bounds. (Location not found in shapefile)"
        st.warning(msg)
        error = True
    if self.time_out_of_bounds:
        if not self.comparison_time:
            msg = "Out of bounds: The period you provided is out of bounds."
        else:
            msg = "Out of bounds: At least one of the periods you provided is out of bounds."
        st.warning(msg)
        error = True
    if self.heatmap_not_supported:
        msg = "Warning: the heatmap is not currently supported due to the small number of points available in a city."
        st.warning(msg)
        error = True
    return error
[docs]
def get_xy_pos_note_under_graph(self):
    """
    Determine the position of the note below the graph from graph type and aggregation type.

    The rules are applied in order, and a later matching rule overwrites the
    position set by an earlier one, so statement order matters.

    Returns
    -------
    tuple
        ``(x, y)`` in paper coordinates (the layout step passes them to
        ``add_annotation`` with ``xref``/``yref="paper"``). Defaults to ``(1, 1)``.
    """
    x, y = 1, 1
    if self.graph_type == "line chart":
        if self.aggreg_type == "raw":
            if not self.comparison_time:
                x, y = 1, -0.65
            else:
                x, y = 1.2, -0.3
        if self.aggreg_type != "raw":
            if not self.comparison_time:
                x, y = 1.1, -0.30
            else:
                x, y = 1.2, -0.30
        # Long periods are drawn as rainbow charts: this overrides the above.
        if self.delta_time >= parameters.threshold_rainbow:
            x, y = 1.1, -0.3
    if self.graph_type == "bar chart":
        x, y = 1, -0.30
    if self.graph_type in ["histogram", "warming stripes", "box plot"]:
        x, y = 1, -0.25
    if self.graph_type == "box plot":
        if self.comparison_time:
            x, y = 1, -0.35
    # Subplot grids override every position set above.
    if self.is_subplot:
        x, y = 1, -0.50
        if self.graph_type == 'line chart':
            # Adjust for the number of compared locations.
            if len(self.location) == 2:
                x, y = 1.1, -0.25
            elif len(self.location) == 3:
                x, y = 1.1, -0.1
            elif len(self.location) == 4:
                x, y = 1.1, -0.1
    return x, y
def __get_strftime_from_agg(self):
    """
    Return the ``strftime`` pattern used to label periods for ``self.aggreg_type``.

    "MS" -> "%m-%Y" and "YS" -> "%Y". "D" and any other aggregation type use
    "%d-%m-%Y".

    Returns
    -------
    str
        The date format string.
    """
    for agg, date_format in (("MS", "%m-%Y"), ("YS", "%Y")):
        if self.aggreg_type == agg:
            return date_format
    return "%d-%m-%Y"
def __get_chart_name(self, type_, idx=0, year=0):
    """
    Generate the trace (legend) name or chart title for a given chart type.

    A table of names is built for every combination of
    ``(self.comparison_time, self.comparison_loc, type_)``, and the entry for
    the current combination is returned.

    Every entry is evaluated eagerly (it is a dict literal). So
    ``loc[idx]["location_name"]`` and ``starttime[idx].strftime(...)`` are
    computed even for entries that are not returned.

    For precipitation, the line-chart names say "Total" instead of "Average"
    and have an empty "min_max" entry.

    Parameters
    ----------
    type_ : str or None
        Chart-type key: "line chart", "bar chart", "histogram_&bar_&box",
        "rainbow" or "heatmap". ``None`` falls back to ``self.graph_type``.
    idx : int, optional
        Index of the location or time period to name. Defaults to 0.
    year : int, optional
        Year shown in "rainbow" names. Defaults to 0.

    Returns
    -------
    str or dict
        A string for most chart types. For "line chart", a dict with the keys
        "raw", "mean" and "min_max".

    Raises
    ------
    KeyError
        If the combination is not in the table, e.g. "bar chart" while
        comparing, or "histogram_&bar_&box" without any comparison.
    """
    if type_ is None:
        type_ = self.graph_type
    loc, starttime, endtime = self.location, self.starttime, self.endtime
    # Single (non-compared) values are repeated so that [idx] works uniformly below.
    if not isinstance(loc, list):
        loc = [loc] * (idx + 1)
    if not isinstance(starttime, list):
        starttime = [starttime] * (idx + 1)
    if not isinstance(endtime, list):
        endtime = [endtime] * (idx + 1)
    key = (self.comparison_time, self.comparison_loc, type_)
    marker_name = {
        # No time comparison, No loc comparison
        (False, False, "line chart"): {
            "raw": f"Raw {self.climate_variable}",
            "mean": "Average {}.".format(self.climate_variable[0:4]),
            "min_max": "Min-Max {}.".format(self.climate_variable[0:4]),
        },
        (False, False, "bar chart"): "Mean",
        (False, False, "rainbow"): "{}. in {}".format(
            self.climate_variable[:4], year
        ),
        (False, False, "heatmap",
         ): f"Heatmap of average {self.climate_variable} from {tools.str_from_datetime(starttime[idx])} to {tools.str_from_datetime(endtime[idx])} in {loc[idx]['location_name']}",
        # No time comparison, loc comparison
        (False, True, "line chart"): {
            "raw": loc[idx]["location_name"],
            "mean": "Average {}.-{}".format(
                self.climate_variable[0:4], loc[idx]["location_name"]
            ),
            "min_max": "Min-Max {}.-{}".format(
                self.climate_variable[0:4], loc[idx]["location_name"]
            ),
        },
        (False, True, "histogram_&bar_&box"): loc[idx]["location_name"],
        (False, True, "rainbow"): "{}. in {}".format(
            self.climate_variable[:4], year
        ),
        (False, True, "heatmap"): f"Average {self.climate_variable} in {loc[idx]['location_name']}",
        # Time comparison, no loc comparison
        (True, False, "line chart"): {
            "raw": starttime[idx].strftime("%m/%d/%Y")
            + " ➡ "
            + endtime[idx].strftime("%m/%d/%Y"),
            "mean": "Average {}.{}".format(
                self.climate_variable[0:4],
                starttime[idx].strftime("%m/%d/%Y")
                + " ➡ "
                + endtime[idx].strftime("%m/%d/%Y"),
            ),
            "min_max": f"Min-Max {starttime[idx].strftime('%m/%d/%Y') + ' ➡ ' + endtime[idx].strftime('%m/%d/%Y')}",
        },
        (True, False, "histogram_&bar_&box"): starttime[idx].strftime("%m/%d/%Y")
        + " ➡ "
        + endtime[idx].strftime("%m/%d/%Y"),
        (True, False, "rainbow"): "{}. in {}".format(
            self.climate_variable[:4], year
        ),
        (True, False, "heatmap"): starttime[idx].strftime("%m/%d/%Y")
        + " ➡ "
        + endtime[idx].strftime("%m/%d/%Y"),
        # Time comparison, loc comparison
        (True, True, "line chart"): {
            "raw": loc[idx]["location_name"]
            + " : "
            + starttime[idx].strftime("%m/%d/%Y")
            + " ➡ "
            + endtime[idx].strftime("%m/%d/%Y"),
            "mean": loc[idx]["location_name"]
            + " : "
            + "Average {}.{}".format(
                self.climate_variable[0:4],
                starttime[idx].strftime("%m/%d/%Y")
                + " ➡ "
                + endtime[idx].strftime("%m/%d/%Y"),
            ),
            "min_max": loc[idx]["location_name"]
            + " : "
            + f"Min-Max {starttime[idx].strftime('%m/%d/%Y') + ' ➡ ' + endtime[idx].strftime('%m/%d/%Y')}",
        },
        (True, True, "histogram_&bar_&box"): loc[idx]["location_name"]
        + " : "
        + starttime[idx].strftime("%m/%d/%Y")
        + " ➡ "
        + endtime[idx].strftime("%m/%d/%Y"),
        (True, True, "rainbow"): "{}. in {}".format(
            self.climate_variable[:4], year
        ),
        (True, True, "heatmap"): loc[idx]["location_name"]
        + " : "
        + starttime[idx].strftime("%m/%d/%Y")
        + " ➡ "
        + endtime[idx].strftime("%m/%d/%Y"),
    }
    # Precipitation is summed rather than averaged, and it has no min/max
    # corridor, so its line-chart names are overridden.
    if self.climate_variable == 'precipitation':
        if key == (False, False, "line chart"):
            marker_name[key] = {
                "raw": f"Raw {self.climate_variable}",
                "mean": "Total {}.".format(self.climate_variable[0:4]),
                "min_max": "",
            }
        elif key == (False, True, "line chart"):
            marker_name[key] = {
                "raw": loc[idx]["location_name"],
                "mean": "Total {}.-{}".format(
                    self.climate_variable[0:4], loc[idx]["location_name"]),
                "min_max": ""}
        elif key == (True, False, 'line chart'):
            marker_name[key] = {
                "raw": starttime[idx].strftime("%m/%d/%Y")
                + " ➡ " + endtime[idx].strftime("%m/%d/%Y"),
                "mean": "Total {}.{}".format(
                    self.climate_variable[0:4],
                    starttime[idx].strftime("%m/%d/%Y")
                    + " ➡ " + endtime[idx].strftime("%m/%d/%Y"), ),
                "min_max": "",
            }
        elif key == (True, True, 'line chart'):
            marker_name[key] = {
                "raw": loc[idx]["location_name"]
                + " : " + starttime[idx].strftime("%m/%d/%Y") + " ➡ "
                + endtime[idx].strftime("%m/%d/%Y"),
                "mean": loc[idx]["location_name"]
                + " : " + "Total {}.{}".format(
                    self.climate_variable[0:4],
                    starttime[idx].strftime("%m/%d/%Y")
                    + " ➡ " + endtime[idx].strftime("%m/%d/%Y"),
                ),
                "min_max": ""}
    return marker_name[key]
def __make_subplots_fig(self, comparison):
    """
    Replace ``self.fig`` with an empty subplot grid, one cell per figure, and set ``self.is_subplot``.

    Grid shape:
    - ``n`` is ``len(self.params_after_processing)``.
    - The grid is as square as possible: ``cols = ceil(sqrt(n))`` and
      ``rows = ceil(n / cols)``.

    Cell titles depend on ``comparison``:
    - "location": the location names;
    - "time": the period ranges ("mm/dd/YYYY ➡ mm/dd/YYYY");
    - "location_time": "<location> : <range>".

    Parameters
    ----------
    comparison : str
        "location", "time" or "location_time".

    Returns
    -------
    tuple
        ``(NB_FIG, cols, rows)``: number of figures, columns and rows.
    """
    self.is_subplot = True

    def period_label(k):
        # "mm/dd/YYYY ➡ mm/dd/YYYY" for the k-th compared period.
        return (
            self.starttime[k].strftime("%m/%d/%Y")
            + " ➡ "
            + self.endtime[k].strftime("%m/%d/%Y")
        )

    if comparison == "location":
        sub_titles = self.params["location"]
    elif comparison == "time":
        sub_titles = [period_label(k) for k in range(len(self.starttime))]
    elif comparison == "location_time":
        places = self.params["location"]
        sub_titles = [
            places[k] + " : " + period_label(k)
            for k in range(len(self.starttime))
        ]

    nb_fig = len(self.params_after_processing)
    n_cols = math.ceil(math.sqrt(nb_fig))
    n_rows = math.ceil(nb_fig / n_cols)
    self.fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=sub_titles)
    return nb_fig, n_cols, n_rows
def __make_subplots_warming_stripes(self, comparison, X, Y):
    """
    Draw one warming-stripe strip per figure in a subplot grid.

    Each strip is a one-row ``px.imshow`` image on the "RdBu_r" scale, and
    its y axis is hidden. When ``self.cross_tenses`` is set for a figure and
    exactly two figures are compared, a dashed green line marks
    ``parameters.time_max_data`` with a "Start of projections" label. Axis
    titles are left to the layout step, which sets them cell by cell.

    Parameters
    ----------
    comparison : str
        Comparison type forwarded to ``__make_subplots_fig`` (grid and cell titles).
    X : list
        Time axis of each strip.
    Y : list
        Values of each strip.
    """
    NB_FIG, cols, rows = self.__make_subplots_fig(comparison)
    # cross_tenses is either one flag for every figure or one flag per figure.
    cross_tenses = self.cross_tenses
    if isinstance(cross_tenses, bool):
        cross_tenses = [cross_tenses] * self.NB_PARAM_ELM
    for i in range(NB_FIG):
        row, col = (i // cols) + 1, (i % cols) + 1
        temp_fig = px.imshow(
            [Y[i]],
            aspect="auto",
            color_continuous_scale="RdBu_r",
            origin="lower",
            labels={
                "color": f"{self.climate_variable.capitalize()} ({self.unit})"
            },
            x=X[i],
        )
        for trace in temp_fig.data:
            self.fig.add_trace(trace, row=row, col=col)
        # NOTE(review): the projection marker is only drawn when exactly two
        # figures are compared; confirm this restriction is intended.
        if cross_tenses[i] and self.NB_PARAM_ELM == 2:
            self.fig.add_vline(
                x=parameters.time_max_data,
                line_width=3,
                line_dash="dash",
                line_color="green",
                row=row,
                col=col,
            )
            self.fig.add_annotation(
                # Label sits just left of the line, offset by 3 % of this period's length (days).
                x=parameters.time_max_data - pd.Timedelta(round(0.03 * self.delta_time_multi[i]), unit="D"),
                y=0.0,
                text="<i>Start of projections</i>",
                textangle=270,
                font=dict(color="green", size=14),
                showarrow=False,
                xref="x",
                yref="paper",
                row=row,
                col=col,
            )
        self.fig.update_yaxes(
            visible=False,
            showticklabels=False,
            row=row,
            col=col,
        )
    self.fig.update_layout(
        coloraxis=dict(
            colorscale="RdBu_r",
            colorbar=dict(
                title=f"Delta {self.climate_variable.capitalize()} ({self.unit})",
            )))
def __make_warming_stripes(self, x, y):
    """
    Replace ``self.fig`` with a single warming-stripes image.

    Drawing:
    - ``y`` is passed as-is to ``px.imshow``; the caller wraps the strip in
      a list.
    - The strip is coloured with the "RdBu_r" scale, and the y axis is
      hidden.

    Projection marker (only when ``self.cross_tenses`` is truthy):
    - a dashed green vertical line at ``parameters.time_max_data``;
    - a "Start of projections" label just left of the line, offset by 3 % of
      ``self.delta_time`` days.
    """
    self.fig = px.imshow(
        y,
        aspect="auto",
        color_continuous_scale="RdBu_r",
        origin="lower",
        labels={"color": f"Delta {self.climate_variable.capitalize()} ({self.unit})"},
        x=x,
    )
    if self.cross_tenses:
        marker_color = "green"
        boundary = parameters.time_max_data
        self.fig.add_vline(
            x=boundary,
            line_width=3,
            line_dash="dash",
            line_color=marker_color,
        )
        offset = pd.Timedelta(round(0.03 * self.delta_time), unit="D")
        self.fig.add_annotation(
            x=boundary - offset,
            y=0.5,
            text="<i>Start of projections</i>",
            textangle=270,
            font=dict(color=marker_color, size=14),
            showarrow=False,
            xref="x",
            yref="paper",
        )
    self.fig.update_layout(yaxis={"visible": False, "showticklabels": False})
def _get_interp_colors(self, years):
    """
    Generate a colour gradient with one RGBA string per year, oldest first.

    Mapping from year to colour:
    - Positions ``t`` in [0, 1] are spread evenly over the years.
    - Each position is mapped piecewise-linearly with
      ``tools.interpolate_color``:
      - dark grey-blue-green -> medium grey-green for ``t <= 0.45``;
      - medium grey-green -> medium grey for ``0.45 < t < 0.85``;
      - medium grey -> pink for ``t >= 0.85``.
    - The most recent year is then forced to red so it stands out.

    Parameters
    ----------
    years : list
        Years to colour; only its length is used.

    Returns
    -------
    list
        ``"rgba(r, g, b, 1)"`` strings, one per year (empty if ``years`` is empty).
    """
    if len(years) == 0:
        # Nothing to colour (and colors[-1] below would raise IndexError).
        return []
    start_color = (45, 65, 65)  # dark grey-blue-green
    mid_inf_color = (115, 120, 115)  # medium grey-green
    mid_sup_color = (200, 200, 200)  # medium grey
    end_color = (255, 100, 100)  # pink
    colors = []
    for norm in np.linspace(0, 1, len(years)):
        if norm <= 0.45:
            color = tools.interpolate_color(start_color, mid_inf_color, norm / 0.45)
        elif norm < 0.85:
            color = tools.interpolate_color(mid_inf_color, mid_sup_color, (norm - 0.45) / 0.4)
        else:
            color = tools.interpolate_color(mid_sup_color, end_color, (norm - 0.85) / 0.15)
        colors.append(f"rgba({color[0]}, {color[1]}, {color[2]}, 1)")
    colors[-1] = "rgba(250, 0, 0, 1)"  # most recent year: red
    return colors
[docs]
def get_interp_colors_anomaly(self, years):
    """
    Colors for an anomaly chart: grey for every year except the last two.

    Parameters
    ----------
    years : list
        Years to colour; only its length is used.

    Returns
    -------
    list of str
        One RGBA string per year. The second-to-last year is orange and the
        last is dark red; all others are grey. Lists shorter than two years
        are handled (a single year is red, no years gives an empty list)
        instead of raising ``IndexError``.
    """
    colors = ["rgba(150, 150, 150, 1)"] * len(years)  # older years: grey
    if len(colors) >= 2:
        colors[-2] = "rgba(255, 153, 51, 1)"  # second-to-last year: orange
    if colors:
        colors[-1] = "rgba(204, 0, 0, 1)"  # last year: red
    return colors
def __make_subplots_rainbow(self, comparison, data):
    """
    Create subplots representing a rainbow visualization based on the provided data.

    Each dataset gets its own subplot. Inside a subplot, every year in the
    period is drawn as one line across the twelve months.

    Parameters
    ----------
    comparison : str
        The type of comparison, forwarded to ``__make_subplots_fig`` to lay out
        the subplot grid.
    data : list
        One dataset per subplot, in row-major subplot order. Each dataset must
        support ``.sel(time=<year string>)`` and expose ``.values`` and
        ``.time.values`` (presumably xarray objects — confirm).

    Notes
    -----
    - When ``self.starttime`` / ``self.endtime`` are lists, dataset ``idx``
      covers ``starttime[idx].year`` .. ``endtime[idx].year``. Otherwise every
      dataset covers the same single range.
    - Years with fewer than 12 values are padded with ``None`` on the missing
      months so gaps appear on the chart.
    - The most recent year is drawn thicker. Years from
      ``parameters.time_max_data.year`` onward are dotted and their legend
      name gets a " (projection)" suffix.
    """
    # Only the column count is needed here. The helper is still called because
    # it presumably sets up the subplot figure on self (assumption — confirm).
    _, cols, _ = self.__make_subplots_fig(comparison)
    month_names = tools.get_month_names()  # loop-invariant
    first_projection_year = parameters.time_max_data.year
    per_dataset_range = isinstance(self.starttime, list)

    for idx, dataset in enumerate(data):
        row = idx // cols + 1
        col = idx % cols + 1
        if per_dataset_range:
            start, end = self.starttime[idx], self.endtime[idx]
        else:
            start, end = self.starttime, self.endtime
        years = list(range(start.year, end.year + 1))
        colors = self._get_interp_colors(years)

        for year, color in zip(years, colors):
            year_data = dataset.sel(time=str(year))
            y_values = year_data.values
            if len(y_values) != 12:
                # Align the months that are present onto Jan..Dec; missing ones become None.
                by_month = {
                    pd.Timestamp(t).month: v
                    for t, v in zip(year_data.time.values, y_values)
                }
                y_values = [by_month.get(m) for m in range(1, 13)]

            is_projection = year >= first_projection_year
            self.fig.add_trace(
                go.Scatter(
                    x=month_names,
                    y=y_values,
                    mode="lines",
                    name=self.__get_chart_name(type_="rainbow", year=year)
                    + (" (projection)" if is_projection else ""),
                    line=dict(
                        color=color,
                        width=2 if year == years[-1] else 1,
                        dash="dot" if is_projection else "solid",
                    ),
                ),
                row=row,
                col=col,
            )
def __make_rainbow(self, data):
    """
    Draw a single "rainbow" chart on ``self.fig``: one line per year across the twelve months.

    Parameters
    ----------
    data : object
        Time-indexed dataset supporting ``.sel(time=<year string>)`` and
        exposing ``.values`` and ``.time.values`` (presumably an xarray
        object — confirm).

    Notes
    -----
    - Years span ``self.starttime.year`` .. ``self.endtime.year`` inclusive.
      Colours come from ``self._get_interp_colors``.
    - Years with fewer than 12 values are padded with ``None`` on the missing
      months so gaps appear on the chart.
    - The most recent year is drawn thicker. Years from
      ``parameters.time_max_data.year`` onward are dotted and their legend
      name gets a " (projection)" suffix.
    """
    years = list(range(self.starttime.year, self.endtime.year + 1))
    colors = self._get_interp_colors(years)
    month_names = tools.get_month_names()  # loop-invariant
    first_projection_year = parameters.time_max_data.year

    for year, color in zip(years, colors):
        # Progress goes through the project's logger instead of a
        # carriage-return print that bypassed logging and never ended its line.
        logger.debug("Processing year: {}", year)
        year_data = data.sel(time=str(year))
        y_values = year_data.values
        if len(y_values) != 12:
            # Align the months that are present onto Jan..Dec; missing ones become None.
            by_month = {
                pd.Timestamp(t).month: v
                for t, v in zip(year_data.time.values, y_values)
            }
            y_values = [by_month.get(m) for m in range(1, 13)]

        # ------ display graph ------
        is_projection = year >= first_projection_year
        self.fig.add_trace(
            go.Scatter(
                x=month_names,
                y=y_values,
                mode="lines",
                name=self.__get_chart_name(type_="rainbow", year=year)
                + (" (projection)" if is_projection else ""),
                line=dict(
                    color=color,
                    width=2 if year == years[-1] else 1,
                    dash="dot" if is_projection else "solid",
                ),
            )
        )
def __get_colors_in_subplots(self):
colors = {"warming stripes": "RdBu_r"}
return colors[self.graph_type]
[docs]
def graph_elm_for_translate(self, cat, lang):
    """
    Return the wording to use for a graph element in the given language.

    Parameters
    ----------
    cat : str
        Category of the graph element. Only ``"graph_type"`` is handled; any
        other category returns ``None``.
    lang : str
        Target language name. For ``"English"`` the raw ``self.graph_type`` is
        returned unchanged.

    Returns
    -------
    str or None
        For non-English languages, a more descriptive English phrase for
        ``self.graph_type`` (e.g. ``"line chart"`` -> ``"curve graph"``).
        Presumably this phrase feeds a later translation step (not visible
        here — confirm). Graph types without a mapping (e.g. ``"raw"``) are
        returned as-is rather than raising ``KeyError``.
    """
    if cat != "graph_type":
        return None
    if lang == "English":
        return self.graph_type
    descriptive_names = {
        "line chart": "curve graph",
        "bar chart": "bar graph",
        "histogram": "histogram",
        "box plot": "box graphic",
        "heatmap": "heatmap",
        "warming stripes": "heatmap",
    }
    return descriptive_names.get(self.graph_type, self.graph_type)
[docs]
def get_mean_val_on_ref_period(self):
    """
    Calculate the average values of the dataset over the reference period.

    A second ``ServiceGeneratePlotlyGraph`` is built from a copy of
    ``self.params`` with ``starttime``/``endtime`` replaced by
    ``parameters.start_date_ref_period`` / ``parameters.end_date_ref_period``
    (the log message states 1971-2000). Its data collection's means are then
    averaged according to its aggregation type.

    Returns
    -------
    list
        - ``'YS'``: one scalar per ``values`` column (every column after the
          first whose name contains ``'values'``), i.e. that collection's mean
          over the whole reference period.
        - ``'MS'``: one Series per processed parameter, indexed by month
          number (index name ``'Month'``), holding the mean for each month
          across the reference years.
        - ``'D'``: one Series per processed parameter, indexed by ``"MM-DD"``
          strings (index name ``'day'``), holding the mean for each calendar day.
        - Any other aggregation type: ``[0]``.
    """
    logger.info("Computing average on reference period : 1971-2000")
    params = self.params.copy()
    params['starttime'] = parameters.start_date_ref_period
    params['endtime'] = parameters.end_date_ref_period
    gg = ServiceGeneratePlotlyGraph(params, 'English', self.user_type)

    list_means = []
    if gg.aggreg_type == 'YS':
        df_mean = gg.data_collection.get_mean(gg.aggreg_type)
        for col in df_mean.columns[1:]:
            if 'values' in col:
                list_means.append(df_mean[col].mean())
    elif gg.aggreg_type in ('MS', 'D'):
        df_mean = gg.data_collection.get_mean(gg.aggreg_type)
        for i in range(1, len(gg.params_after_processing) + 1):
            times = df_mean[f"time_collection_{i}"].dt
            # Group the values Series by a derived key instead of adding a column
            # to a column-subset slice (which triggers SettingWithCopyWarning).
            if gg.aggreg_type == 'MS':
                key = times.month.rename('Month')
            else:
                key = times.strftime("%m-%d").rename('day')
            list_means.append(df_mean[f"values_collection_{i}"].groupby(key).mean())
    else:
        list_means.append(0)
    return list_means
[docs]
def compute_diff_ref(self, x, y, ref, mode):
    """
    Compute anomalies: observed values minus reference-period means.

    Parameters
    ----------
    x : sequence or list of sequences
        Dates of the observed values ('simple'), or one date sequence per
        dataset ('multi'). Only read for ``'MS'`` and ``'D'`` aggregation;
        must be datetime-like there.
    y : sequence or list of sequences
        Observed values ('simple'), or one value sequence per dataset ('multi').
    ref : scalar, mapping/Series, or list of those
        Reference means. For ``'YS'``: a scalar ('simple') or one per dataset
        ('multi'). For ``'MS'``: keyed by month number. For ``'D'``: keyed by
        ``"MM-DD"`` strings. These match the shapes returned by
        ``get_mean_val_on_ref_period``.
    mode : str
        ``'simple'`` for a single dataset, ``'multi'`` for a list of datasets.

    Returns
    -------
    list or numpy.ndarray or None
        - ``'simple'`` + ``'YS'``: list of ``value - ref``.
        - ``'simple'`` + ``'MS'``/``'D'``: ndarray of anomalies (NaN where the
          month/day has no reference value).
        - ``'multi'``: list with one result per dataset.
        - Any other mode/aggregation combination: ``None``.
    """
    def seasonal_anomaly(dates, values, ref_means):
        # Subtract from each value the reference mean of its calendar slot:
        # the month for 'MS', the "MM-DD" day for 'D'.
        frame = pd.DataFrame({'Date': dates, 'Values': values})
        if self.aggreg_type == 'MS':
            slot = frame['Date'].dt.month
        else:
            slot = frame['Date'].dt.strftime("%m-%d")
        return (frame['Values'] - slot.map(ref_means)).values

    if mode == 'simple':
        if self.aggreg_type == 'YS':
            return [val - ref for val in y]
        if self.aggreg_type in ('MS', 'D'):
            return seasonal_anomaly(x, y, ref)
    elif mode == 'multi':
        if self.aggreg_type == 'YS':
            return [values - ref[i] for i, values in enumerate(y)]
        if self.aggreg_type in ('MS', 'D'):
            return [seasonal_anomaly(dates, y[i], ref[i]) for i, dates in enumerate(x)]
    return None
[docs]
def add_tense_data(self, df):
    """
    Label each timestamp as observed ("past") or projected ("projection").

    For every collection ``i`` in ``1..self.NB_PARAM_ELM``, adds a column
    ``tense_collection_i``. A row is ``"past"`` when its
    ``time_collection_i`` value is strictly before
    ``parameters.time_max_data``, and ``"projection"`` otherwise. Missing
    timestamps compare False and are therefore labelled ``"projection"``.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing the ``time_collection_i`` columns. Modified in place.

    Returns
    -------
    pandas.DataFrame
        The same ``df`` object, with the tense columns added.
    """
    cutoff = parameters.time_max_data
    for i in range(1, self.NB_PARAM_ELM + 1):
        # Vectorized comparison instead of a per-row Python lambda via apply().
        is_past = df[f"time_collection_{i}"] < cutoff
        df[f"tense_collection_{i}"] = np.where(is_past, "past", "projection")
    return df
@property
def colors(self):
    """Ten-colour palette returned by ``tools.get_10_colors()``, fetched on every access."""
    palette = tools.get_10_colors()
    return palette
@property
def colors_gray(self):
    """Ten-colour grey palette returned by ``tools.get_10_colors_gray()``, fetched on every access."""
    gray_palette = tools.get_10_colors_gray()
    return gray_palette
# Manual smoke run: build a sample request and render it with show=True.
# The trailing comments on several parameters list alternative values to try.
if __name__ == "__main__":
# Sample request: two comparison periods (lists of start/end dates) for one location.
params_ = {
"starttime": ['2010-01-01', '2011-10-01'], # '2020-01-01' - ['2023-03-01', '1990-01-01'],
"endtime": ['2020-12-01', '2021-09-01'], # '2020-12-31' - ['2024-11-01', '1990-09-01'],
"location": 'Rome', # 'Rome' - ['Milano', 'Napoli'],
"climate_variable": "temperature",
"graph_type": "warming stripes", # raw, line chart, histogram, bar chart, box plot, warming stripes, heatmap
"aggreg_type": "raw",
"anomaly": False,
}
langage = "French"
user_type = "normal"
# Seed empty Streamlit session-state dicts before generating
# (plots_history is what the commented-out line at the end reads from).
st.session_state.plots_history = {}
st.session_state.messages = {}
graph_generator = ServiceGeneratePlotlyGraph(params_, langage, user_type)
graph_generator.generate(show=True)
# Convenience aliases for interactive inspection of the generator and its figure.
self = graph_generator
fig = self.fig
# fig = st.session_state.plots_history["msg_-1"]