# Copyright 2024 Mews
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import json
from abc import abstractmethod
from datetime import timedelta
import argostranslate.translate as translator
import geopandas as gpd
import pandas as pd
from loguru import logger
import diva.tools as tools
from diva import parameters
from diva.data.dataset import DataCollection
from diva.data.shapefile_masker import ShapefileReader
from diva.logging.logger import crossing_point
class GraphTexts:
    """
    A class to manage the textual content of a graph, including titles and axis labels, with the ability
    to translate the (English) text into the graph's target language.

    Attributes
    ----------
    graph_type : str
        The kind of graph (e.g. "line chart"). For English, ``translate`` turns it
        into a phrase such as "Line chart of".
    title : str
        The title of the graph.
    x_axis_title : str
        The title of the x-axis.
    under_x_axis_title1 : str
        The first subtitle or description below the x-axis.
    under_x_axis_title2 : str
        The second subtitle or description below the x-axis.
    y_axis_title : str
        The title of the y-axis.
    tips : str
        Additional tips or instructions related to the graph.
    langage : str
        The language the graph texts should be presented in. The texts themselves
        are expected to be authored in English; ``translate`` converts them.
    """

    def __init__(
        self,
        graph_type: str,
        title: str,
        x_axis_title: str,
        under_x_axis_title1,
        under_x_axis_title2,
        y_axis_title: str,
        tips,
        langage,
    ):
        """
        Store the graph texts as given; no translation happens here.

        Parameters
        ----------
        graph_type, title, x_axis_title, under_x_axis_title1, under_x_axis_title2,
        y_axis_title, tips, langage
            See the class attributes of the same name.
        """
        self.graph_type = graph_type
        self.title = title
        self.x_axis_title = x_axis_title
        self.under_x_axis_title1 = under_x_axis_title1
        self.under_x_axis_title2 = under_x_axis_title2
        self.y_axis_title = y_axis_title
        self.tips = tips
        self.langage = langage

    @crossing_point("""Get all textual content of a graph into a list""")
    def textual_content(self):
        """
        Retrieve all textual content of the graph as a list.

        The order matches the corresponding positional parameters of
        ``GraphTexts.__init__`` after ``graph_type``.

        Returns
        -------
        list
            ``[title, x_axis_title, under_x_axis_title1, under_x_axis_title2,
            y_axis_title, tips]``.
        """
        return [
            self.title,
            self.x_axis_title,
            self.under_x_axis_title1,
            self.under_x_axis_title2,
            self.y_axis_title,
            self.tips,
        ]

    @crossing_point("""Translate the textual content of a graph""")
    def translate(self):
        """
        Translate the textual content of the graph from English into ``self.langage``.

        The graph type is first looked up in the JSON mapping file
        ``parameters.graph_type_langage_mapping_from_en``. If the target
        language, or the graph type within it, is missing from that mapping, the
        graph type is machine-translated with argostranslate instead. All other
        texts are machine-translated.

        ``self`` is never modified, so calling ``translate`` more than once is safe.

        Returns
        -------
        GraphTexts
            A new GraphTexts object with the translated content. For English, a
            new object whose ``graph_type`` is capitalized and suffixed with " of"
            and whose other texts are unchanged.
        """
        target_lang = tools.langage_to_iso(self.langage)
        if target_lang == "en":
            # Build a new object instead of mutating self: mutating made a second
            # call produce "Line chart of of".
            return GraphTexts(
                graph_type=self.graph_type.capitalize() + " of",
                title=self.title,
                x_axis_title=self.x_axis_title,
                under_x_axis_title1=self.under_x_axis_title1,
                under_x_axis_title2=self.under_x_axis_title2,
                y_axis_title=self.y_axis_title,
                tips=self.tips,
                langage=self.langage,
            )

        # JSON is UTF-8 by specification; do not rely on the platform's default
        # encoding for a file that may hold non-ASCII translations.
        with open(parameters.graph_type_langage_mapping_from_en, encoding="utf-8") as f:
            graph_types_mapping = json.load(f)

        translated_graph_type = graph_types_mapping.get(target_lang, {}).get(self.graph_type)
        if translated_graph_type is None:
            # Language or graph type not covered by the curated mapping.
            translated_graph_type = translator.translate(self.graph_type, "en", target_lang)

        translated_texts = [
            translator.translate(text, "en", target_lang)
            for text in self.textual_content()
        ]
        return GraphTexts(translated_graph_type, *translated_texts, self.langage)
class IGraphGenerator:
    """
    Base class for graph generators.

    It validates and normalizes the request parameters (variable, period(s),
    location(s), graph and aggregation type). When everything is in bounds, it
    prepares the DataCollection that a concrete ``generate`` works on.

    NOTE(review): ``generate`` is decorated with ``@abstractmethod`` but this
    class does not derive from ``abc.ABC``, so the abstract method is not
    enforced at instantiation time.
    """

    # climate_variable -> (dataset variable name, display unit)
    _VARIABLE_INFO = {
        "temperature": ("t2m", "°C"),
        "wind": ("fg10", "km/h"),
        "precipitation": ("tp", "mm"),
        "pressure": ("sp", "kPa"),
    }

    # Requested aggregation scale -> pandas offset alias; anything else is "raw".
    _AGGREG_CODES = {
        "year": "YS",
        "month": "MS",
        "week": "W",
        "day": "D",
    }

    def __init__(self, params, langage, user_type):
        """
        Initialize the generator with the parameters for graph generation.

        Parameters
        ----------
        params : dict
            Parameters for graph generation. The keys read here are:
            - ``climate_variable`` (e.g. temperature, pressure, wind, precipitation)
            - ``location``: a location string, or a list of them for a location comparison
            - ``starttime`` / ``endtime``: "%Y-%m-%d" dates, or lists of them for a time comparison
            - ``graph_type`` (e.g. line chart, bar chart, box plot, heatmap,
              warming stripes, or "Unknown")
            - ``aggreg_type`` (e.g. raw, day, week, month, year, or "Unknown")
            The caller's dict is not modified. A copy with ``anomaly`` forced to
            False is stored and exposed through the ``params`` property.
        langage : str
            Language of the request, exposed through the ``langage`` property.
        user_type : str
            "normal" or "granted"; selects the data collections (see ``get_collection``).

        The constructor also loads the Europe and Europe-cities shapefiles and
        processes the parameters. If neither the location nor the period is out
        of bounds, it initializes ``self.data_collection``.
        """
        logger.info(f"Received params: {params}")
        # Work on a copy so building a generator has no side effect on the
        # caller's dict.
        params = params.copy()
        params['anomaly'] = False
        logger.warning("Anomaly mode is disabled !")
        logger.debug(f"PARAMS: {params}")
        self.user_type = user_type
        self.__params = params
        self.__langage = langage
        self.__climate_variable = params["climate_variable"]
        self.__location = params["location"]
        self.__starttime = params["starttime"]
        self.__endtime = params["endtime"]
        self.__graph_type = params["graph_type"]
        self.__aggreg_type = params["aggreg_type"]
        self.__data = None
        self.__comparison_loc = False
        self.__comparison_time = False
        self.comparison_type = 'single'
        self.__out_of_bounds = False
        self.__time_out_of_bounds = False
        self.__heatmap_not_supported = False
        self.__list_not_found_loc = []
        # Defaults when the request did not specify a graph/aggregation type.
        if self.__graph_type == "Unknown":
            if self.__climate_variable == "precipitation":
                self.__graph_type = "bar chart"
            else:
                self.__graph_type = "line chart"
        if self.__aggreg_type == "Unknown":
            self.__aggreg_type = "raw"
        self.__graph_context = "[UNAVAILABLE]"
        self.is_subplot = False
        self.__path_var = None
        self.unit = None
        self.var_name = None
        self.__shapefile_europe = gpd.read_file(
            parameters.path_data + parameters.shapefile_europe
        )
        self.__shapefile_europe_cities = gpd.read_file(
            parameters.path_data + parameters.shapefile_europe_cities
        )
        self.process_params()
        if not self.__out_of_bounds and not self.__time_out_of_bounds:
            self.data_collection = self.initialise_data_collection()

    @abstractmethod
    def generate(self):
        """Build the graph. Must be implemented by concrete generators."""
        ...

    def get_collection(self):
        """
        Return the cache collections configured for this generator's user type.

        Raises
        ------
        ValueError
            If ``user_type`` is neither "normal" nor "granted".
        """
        if self.user_type == "normal":
            cache_b_collections = parameters.cache_b_collections_normal
        elif self.user_type == "granted":
            cache_b_collections = parameters.cache_b_collections_granted
        else:
            raise ValueError("Invalid user type. Must be 'normal' or 'granted'.")
        return cache_b_collections

    @crossing_point(
        """Initialize the data collection class and calls main methods for the graph package""")
    def initialise_data_collection(self):
        """
        Build the DataCollection for the processed parameters.

        The collection is sampled on every requested period and masked with
        every requested location. Except for heatmaps, it is then aggregated
        spatially.

        Returns
        -------
        DataCollection
            The collection returned by the last DataCollection step applied.
        """
        params_list = self.params_after_processing
        time_intervals = [[param['starttime'], param['endtime']] for param in params_list]
        locs = [param['location'] for param in params_list]
        # NOTE(review): precipitation ('tp') presumably needs a sum rather than a
        # mean when aggregated, but spatial_aggregation() is called without an
        # aggregation function — confirm against DataCollection's API.
        cache_b_collections = self.get_collection()
        logger.debug(f"cache_b_collections: {cache_b_collections}")
        dc = DataCollection(cache_b_collections, self.var_name)
        dc = dc.sample_time(time_intervals)
        dc = dc.apply_masks(locs)
        # Heatmaps keep the spatial dimension.
        if params_list[0]['graph_type'] != 'heatmap':
            dc = dc.spatial_aggregation()
        return dc

    # ---------------------------- Processing ----------------------------

    @crossing_point(
        """Process the parameters : convert to timestamp dates, get info about locations and perform out-of-bounds checks""")
    def process_params(self):
        """
        Process input parameters for graph generation.

        Steps, in order:
        - Determine the dataset variable name and unit from the climate variable.
        - Convert start/end dates to pandas Timestamps, order them, extend each
          end to the end of its day, and clamp them to the available data range.
          Flag periods that are completely out of range.
        - Stop early if the period is out of bounds.
        - Resolve location(s) and check that they are covered by the shapefiles.
        - Derive the comparison type and the aggregation type.

        Returns
        -------
        None
        """
        self._set_variable_info()
        self._process_period()
        if self.__time_out_of_bounds:
            return None
        self._process_location()
        self._set_comparison_type()
        self._process_aggregation()
        return None

    def _set_variable_info(self):
        """Set ``var_name`` and ``unit`` from the climate variable (left unchanged if unknown)."""
        info = self._VARIABLE_INFO.get(self.__climate_variable)
        if info is not None:
            self.var_name, self.unit = info

    def _process_period(self):
        """
        Parse, order, extend and clamp the requested period(s).

        A single period and a list of periods (time comparison) go through the
        same per-period logic:
        - ``time_out_of_bounds`` is set when a period lies entirely before
          ``parameters.time_min_data`` or entirely after
          ``parameters.time_max_data_proj``, or when a date is outside pandas'
          Timestamp range.
        - Periods that only partially overlap the data are clamped to it.
        """
        self.__comparison_time = isinstance(self.__starttime, list)
        if self.__comparison_time:
            raw_starts, raw_ends = self.__starttime, self.__endtime
        else:
            raw_starts, raw_ends = [self.__starttime], [self.__endtime]
        try:
            # New lists, so the caller's date lists are never mutated.
            starts = [pd.to_datetime(date, format="%Y-%m-%d") for date in raw_starts]
            ends = [pd.to_datetime(date, format="%Y-%m-%d") for date in raw_ends]
            for i, raw_start in enumerate(starts):
                start, end = tools.put_in_order_dates(raw_start, ends[i])
                # Include (almost) the whole last day in the period.
                end = end + timedelta(hours=23.99)
                # Out-of-bounds check uses the unclamped period.
                if start < parameters.time_min_data and end < parameters.time_min_data:
                    self.__time_out_of_bounds = True
                if (
                    start > parameters.time_max_data_proj
                    and end > parameters.time_max_data_proj
                ):
                    self.__time_out_of_bounds = True
                # Clamp periods that only partially overlap the available data.
                if start < parameters.time_min_data:
                    start = parameters.time_min_data
                if end > parameters.time_max_data_proj:
                    end = parameters.time_max_data_proj
                starts[i], ends[i] = start, end
        except pd.errors.OutOfBoundsDatetime:
            self.__time_out_of_bounds = True
            return
        if self.__comparison_time:
            self.__starttime, self.__endtime = starts, ends
        else:
            self.__starttime, self.__endtime = starts[0], ends[0]

    def _is_known_location(self, location):
        """
        Return whether ``location`` is covered by the available shapefiles.

        Locations whose ``addresstype`` is not in ``parameters.large_regions_type``
        are looked up by ``location_name_orig`` (with " - " normalized to "-")
        in the LAU_NAME column of the Europe cities shapefile. Other locations are
        looked up by ``location_name`` in ``ShapefileReader.list_countries()``.
        """
        if location["addresstype"] not in parameters.large_regions_type:
            city_name = location["location_name_orig"].replace(" - ", "-")
            return city_name in self.__shapefile_europe_cities.LAU_NAME.values
        return location["location_name"] in ShapefileReader.list_countries()

    def _process_location(self):
        """
        Resolve the requested location(s) and flag unknown ones.

        For a single location string, this also flags heatmaps as unsupported for
        non-large-region locations. For a list (location comparison), every
        unknown location name is collected in ``list_not_found_loc``.
        """
        if isinstance(self.__location, str):
            self.__location = tools.get_infos_from_location(self.__location)
            logger.debug(f"location info: {self.__location}")
            if not self._is_known_location(self.__location):
                self.__out_of_bounds = True
            # NOTE(review): nesting reconstructed from an unindented source;
            # assumed to apply to every non-large-region location, not only
            # unknown ones — confirm.
            if (
                self.__location["addresstype"] not in parameters.large_regions_type
                and self.__graph_type == "heatmap"
            ):
                self.__heatmap_not_supported = True
        elif isinstance(self.__location, list):
            self.__comparison_loc = True
            self.__location = [
                tools.get_infos_from_location(loc) for loc in self.__location
            ]
            for location in self.__location:
                if not self._is_known_location(location):
                    self.__list_not_found_loc.append(location["location_name"])
                    self.__out_of_bounds = True

    def _set_comparison_type(self):
        """Derive ``comparison_type`` from the location/time comparison flags."""
        self.comparison_type = {
            (False, False): 'single',
            (True, False): 'location',
            (False, True): 'time',
            (True, True): 'location_time',
        }[(self.__comparison_loc, self.__comparison_time)]

    def _process_aggregation(self):
        """Map the requested aggregation to a pandas alias and adapt it to the graph type."""
        self.__aggreg_type = self._AGGREG_CODES.get(self.__aggreg_type, "raw")
        # Some graph types cannot show raw (or too fine-grained) data.
        if self.__aggreg_type == "raw" and self.__graph_type in ["bar chart", "box plot"]:
            self.__aggreg_type = parameters.defaut_time_aggreg_if_necessary
        if self.__aggreg_type in ["raw", "W", "D"] and self.__graph_type == "warming stripes":
            self.__aggreg_type = "YS"
        # Long line charts are aggregated monthly. Anomaly is forced to False in
        # __init__, so the anomaly branch is currently unreachable.
        if self.__graph_type == 'line chart' and self.delta_time >= parameters.threshold_rainbow:
            if not self.params['anomaly']:
                self.__aggreg_type = "MS"
        elif self.params['anomaly']:
            # can be changed to integrate other agreg.
            self.__aggreg_type = "MS"
        if self.__graph_type == 'heatmap':
            self.__aggreg_type = 'average over the period'
        if self.__graph_type == 'warming stripes' and self.__aggreg_type == 'YS':
            self._snap_periods_to_whole_years()

    def _snap_periods_to_whole_years(self):
        """Extend every period to whole calendar years (Jan 1 of start year -> Dec 31 of end year)."""
        if not isinstance(self.__starttime, list):
            self.__starttime = pd.to_datetime(f"{self.__starttime.year}-01-01")
            self.__endtime = pd.to_datetime(f"{self.__endtime.year}-12-31")
        else:
            self.__starttime = [
                pd.to_datetime(f"{start.year}-01-01") for start in self.__starttime
            ]
            self.__endtime = [
                pd.to_datetime(f"{end.year}-12-31") for end in self.__endtime
            ]

    @property
    def params_after_processing(self):
        """
        One parameter dict per plotted element.

        The combinations depend on the comparison flags:
        - single: the period and location as given
        - location comparison: one dict per location, same period
        - time comparison: one dict per (start, end) pair, same location
        - location and time comparison: the i-th period paired with the i-th location

        Each dict has the keys starttime, endtime, location, climate_variable
        (the dataset variable name), graph_type and aggreg_type.
        """
        if self.comparison_loc and self.comparison_time:
            elements = [
                (self.starttime[i], self.endtime[i], self.location[i])
                for i in range(len(self.location))
            ]
        elif self.comparison_loc:
            elements = [(self.starttime, self.endtime, loc) for loc in self.location]
        elif self.comparison_time:
            elements = [
                (start, end, self.location)
                for start, end in zip(self.starttime, self.endtime)
            ]
        else:
            elements = [(self.starttime, self.endtime, self.location)]
        return [
            {
                "starttime": start,
                "endtime": end,
                "location": loc,
                "climate_variable": self.var_name,
                "graph_type": self.graph_type,
                "aggreg_type": self.aggreg_type,
            }
            for start, end, loc in elements
        ]

    @property
    def context(self):
        return self.__graph_context

    @property
    def params(self):
        return self.__params

    @property
    def langage(self):
        return self.__langage

    @property
    def climate_variable(self):
        return self.__climate_variable

    @property
    def location(self):
        return self.__location

    @property
    def starttime(self):
        return self.__starttime

    @property
    def endtime(self):
        return self.__endtime

    @property
    def graph_type(self):
        return self.__graph_type

    @property
    def aggreg_type(self):
        return self.__aggreg_type

    @property
    def data(self):
        return self.__data

    @property
    def comparison_loc(self):
        return self.__comparison_loc

    @property
    def comparison_time(self):
        return self.__comparison_time

    @property
    def out_of_bounds(self):
        return self.__out_of_bounds

    @property
    def time_out_of_bounds(self):
        return self.__time_out_of_bounds

    @property
    def heatmap_not_supported(self):
        return self.__heatmap_not_supported

    @property
    def list_not_found_loc(self):
        return self.__list_not_found_loc

    @property
    def graph_context(self):
        return self.__graph_context

    @property
    def shapefile_europe(self):
        return self.__shapefile_europe

    @property
    def shapefile_europe_cities(self):
        return self.__shapefile_europe_cities

    @property
    def delta_time(self):
        """Length in days of the (first) period, or None when the period is out of bounds."""
        if not self.__time_out_of_bounds:
            if not self.__comparison_time:
                dt = (self.__endtime - self.__starttime).days
            else:
                dt = (self.__endtime[0] - self.__starttime[0]).days
            return dt
        return None

    @property
    def delta_time_multi(self):
        """Period length in days for each plotted element, or None when out of bounds."""
        if not self.__time_out_of_bounds:
            if not self.__comparison_time:
                dt = [(self.__endtime - self.__starttime).days] * self.NB_PARAM_ELM
            else:
                dt = [
                    (endtime - starttime).days
                    for (endtime, starttime) in zip(self.__endtime, self.__starttime)
                ]
            return dt
        return None

    @property
    def NB_PARAM_ELM(self):
        """Number of plotted elements (length of ``params_after_processing``)."""
        return len(self.params_after_processing)

    @property
    def cross_tenses(self):
        """
        Whether the period contains ``parameters.time_max_data``.

        Returns a single bool without a time comparison, otherwise one bool per
        element. ``time_max_data`` is presumably the end of the historical data
        (as opposed to ``time_max_data_proj`` for projections) — confirm.
        """
        if not self.comparison_time:
            return self.starttime <= parameters.time_max_data <= self.endtime
        return [
            self.starttime[i] <= parameters.time_max_data <= self.endtime[i]
            for i in range(self.NB_PARAM_ELM)
        ]