# Source code for diva.graphs.graph_generator

# Copyright 2024 Mews
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import json
from abc import abstractmethod
from datetime import timedelta

import argostranslate.translate as translator
import geopandas as gpd
import pandas as pd
from loguru import logger

import diva.tools as tools
from diva import parameters
from diva.data.dataset import DataCollection
from diva.data.shapefile_masker import ShapefileReader
from diva.logging.logger import crossing_point


class GraphTexts:
    """
    A class to manage the textual content of a graph, including titles and
    axis labels, with the ability to translate the text from English into
    the language of the conversation.

    Attributes
    ----------
    graph_type : str
        The type of graph (e.g. "line chart"); used as a text fragment.
    title : str
        The title of the graph.
    x_axis_title : str
        The title of the x-axis.
    under_x_axis_title1 : str
        The first subtitle or description below the x-axis.
    under_x_axis_title2 : str
        The second subtitle or description below the x-axis.
    y_axis_title : str
        The title of the y-axis.
    tips : str
        Additional tips or instructions related to the graph.
    langage : str
        The language in which the graph texts must be displayed.
    """

    def __init__(
        self,
        graph_type: str,
        title: str,
        x_axis_title: str,
        under_x_axis_title1,
        under_x_axis_title2,
        y_axis_title: str,
        tips,
        langage,
    ):
        self.graph_type = graph_type
        self.title = title
        self.x_axis_title = x_axis_title
        self.under_x_axis_title1 = under_x_axis_title1
        self.under_x_axis_title2 = under_x_axis_title2
        self.y_axis_title = y_axis_title
        self.tips = tips
        self.langage = langage

    @crossing_point("""Get all textual content of a graph into a list""")
    def textual_content(self):
        """
        Retrieve all textual content of the graph as a list.

        The order matches the positional parameters of ``__init__`` that sit
        between ``graph_type`` and ``langage``; ``translate`` relies on it.

        Returns
        -------
        list
            A list containing the title, axis titles, subtitles, and tips
            for the graph.
        """
        return [
            self.title,
            self.x_axis_title,
            self.under_x_axis_title1,
            self.under_x_axis_title2,
            self.y_axis_title,
            self.tips,
        ]

    @crossing_point("""Translate the textual content of a graph""")
    def translate(self):
        """
        Translate the textual content of the graph from English into
        ``self.langage``.

        For a non-English language the graph type is looked up in the JSON
        mapping file referenced by
        ``parameters.graph_type_langage_mapping_from_en``; when the language
        or the graph type is missing from that mapping, the graph type is
        machine-translated like the other texts.

        For English, the instance itself is modified in place:
        ``graph_type`` is capitalised and suffixed with " of". Calling this
        method again on the same English instance appends the suffix again.

        Returns
        -------
        GraphTexts
            A new GraphTexts object with the translated content, or this
            same object if the language is English.
        """
        target_lang = tools.langage_to_iso(self.langage)

        if target_lang == "en":
            self.graph_type = self.graph_type.capitalize() + " of"
            return self

        # JSON is UTF-8 by specification; do not depend on the locale's
        # default encoding to decode accented translations.
        with open(
            parameters.graph_type_langage_mapping_from_en, encoding="utf-8"
        ) as f:
            graph_types_mapping = json.load(f)

        # Prefer the curated translation, fall back to machine translation
        # when either the language or the graph type is not mapped.
        translated_graph_type = graph_types_mapping.get(target_lang, {}).get(
            self.graph_type
        )
        if translated_graph_type is None:
            translated_graph_type = translator.translate(
                self.graph_type, "en", target_lang
            )

        translated_texts = [
            translator.translate(text, "en", target_lang)
            for text in self.textual_content()
        ]
        return GraphTexts(translated_graph_type, *translated_texts, self.langage)
class IGraphGenerator:
    """
    Base class for graph generators.

    The constructor normalises the request parameters (variable, period,
    location, graph type, aggregation), flags out-of-bounds requests and,
    when the request is valid, builds the data collection used by
    ``generate``.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` marker on ``generate`` is not enforced when the class
    is instantiated.
    """

    def __init__(self, params, langage, user_type):
        """
        Initializes the graph generator with the parameters for graph
        generation.

        Parameters
        ----------
        params : dict
            A dictionary containing parameters for graph generation. The keys
            read here are:
            - "climate_variable" (e.g., temperature, pressure, wind,
              precipitation)
            - "location" (a name, or a list of names for a comparison)
            - "starttime" and "endtime" (dates, or lists of dates for a
              comparison)
            - "graph_type" (e.g., line chart, bar chart, heatmap)
            - "aggreg_type" (e.g., year, month, week, day, raw)
            The key "anomaly" is forced to False on the dictionary that is
            passed in (anomaly mode is currently disabled).
        langage : str
            Language of the conversation, exposed through ``langage``.
        user_type : str
            "normal" or "granted"; selects the collections in
            ``get_collection``.

        This constructor also loads the shapefiles for geographical data and
        processes the input parameters.
        """
        logger.info(f"Received params: {params}")
        # NOTE(review): this writes into the caller's dict before it is
        # copied; kept as-is because the warning below suggests the override
        # is deliberate.
        params['anomaly'] = False
        logger.warning("Anomaly mode is disabled !")
        logger.debug(f"PARAMS: {params}")

        self.user_type = user_type
        self.__params = params.copy()
        self.__langage = langage
        self.__climate_variable = params["climate_variable"]
        self.__location = params["location"]
        self.__starttime = params["starttime"]
        self.__endtime = params["endtime"]
        self.__graph_type = params["graph_type"]
        self.__aggreg_type = params["aggreg_type"]
        self.__data = None

        self.__comparison_loc = False
        self.__comparison_time = False
        self.comparison_type = 'single'

        self.__out_of_bounds = False
        self.__time_out_of_bounds = False
        self.__heatmap_not_supported = False
        self.__list_not_found_loc = []

        # Defaults used when the request did not specify a graph type or an
        # aggregation.
        if self.__graph_type == "Unknown":
            if self.__climate_variable == "precipitation":
                self.__graph_type = "bar chart"
            else:
                self.__graph_type = "line chart"
        if self.__aggreg_type == "Unknown":
            self.__aggreg_type = "raw"

        self.__graph_context = "[UNAVAILABLE]"
        self.is_subplot = False
        self.__path_var = None
        self.unit = None
        self.var_name = None
        # Always defined, so that reading it on an out-of-bounds request
        # yields None instead of raising AttributeError.
        self.data_collection = None

        self.__shapefile_europe = gpd.read_file(
            parameters.path_data + parameters.shapefile_europe
        )
        self.__shapefile_europe_cities = gpd.read_file(
            parameters.path_data + parameters.shapefile_europe_cities
        )

        self.process_params()
        if not self.__out_of_bounds and not self.__time_out_of_bounds:
            self.data_collection = self.initialise_data_collection()

    @abstractmethod
    def generate(self):
        """Generate the graph; to be implemented by subclasses."""
        ...

    def get_collection(self):
        """
        Return the cached collections matching ``self.user_type``.

        Raises
        ------
        ValueError
            If ``user_type`` is neither 'normal' nor 'granted'.
        """
        if self.user_type == "normal":
            return parameters.cache_b_collections_normal
        if self.user_type == "granted":
            return parameters.cache_b_collections_granted
        raise ValueError("Invalid user type. Must be 'normal' or 'granted'.")

    @crossing_point(
        """Initialize the data collection class and calls main methods for the graph package""")
    def initialise_data_collection(self):
        """
        Build the ``DataCollection`` for the processed parameters.

        The collection is sampled on the requested time intervals, masked on
        the requested locations and, unless the graph is a heatmap,
        spatially aggregated.

        Returns
        -------
        The object returned by the last ``DataCollection`` step applied.
        """
        # prepare input parameters
        processed = self.params_after_processing
        time_intervals = [[p['starttime'], p['endtime']] for p in processed]
        locs = [p['location'] for p in processed]
        # NOTE(review): the previous code computed an aggregation name here
        # ('sum' for "tp", 'mean' otherwise) but never passed it to
        # DataCollection; confirm whether spatial_aggregation should get it.

        # initialise the datacollection class
        cache_b_collections = self.get_collection()
        logger.debug(f"cache_b_collections: {cache_b_collections}")
        dc = DataCollection(cache_b_collections, self.var_name)
        dc = dc.sample_time(time_intervals)
        dc = dc.apply_masks(locs)
        if processed[0]['graph_type'] != 'heatmap':
            dc = dc.spatial_aggregation()
        return dc

    # ---------------------------- Processing ----------------------------
    @crossing_point(
        """Process the parameters : convert to timestamp dates, get info about locations and perform out-of-bounds checks""")
    def process_params(self):
        """
        Processes input parameters for graph generation:
        - Determines the variable name and unit based on the climate
          variable.
        - Converts start and end times to pandas timestamps and handles
          out-of-bounds dates.
        - Processes location information and checks for validity.
        - Determines the comparison type and the aggregation type.

        Processing stops right after the period step when the period is out
        of bounds.

        Returns
        -------
        None
        """
        self.__process_variable()

        self.__process_period()
        if self.__time_out_of_bounds:
            return None

        self.__process_location()

        # Get type of comparison:
        if self.__comparison_loc and self.__comparison_time:
            self.comparison_type = 'location_time'
        elif self.__comparison_loc:
            self.comparison_type = 'location'
        elif self.__comparison_time:
            self.comparison_type = 'time'
        else:
            self.comparison_type = 'single'

        self.__process_aggregation()
        return None

    def __process_variable(self):
        """Set ``var_name`` and ``unit`` from the climate variable, if known."""
        if self.__climate_variable == "temperature":
            self.var_name = "t2m"
            self.unit = "°C"
        elif self.__climate_variable == "wind":
            self.var_name = "fg10"
            self.unit = "km/h"
        elif self.__climate_variable == "precipitation":
            self.var_name = "tp"
            self.unit = "mm"
        elif self.__climate_variable == "pressure":
            self.var_name = "sp"
            self.unit = "kPa"

    @staticmethod
    def __is_outside_data_range(start, end):
        """Tell whether [start, end] lies entirely before or after the data."""
        entirely_before = (
            start < parameters.time_min_data and end < parameters.time_min_data
        )
        entirely_after = (
            start > parameters.time_max_data_proj
            and end > parameters.time_max_data_proj
        )
        return entirely_before or entirely_after

    @staticmethod
    def __clamp_to_data_range(start, end):
        """Clip start/end to the first and last available data dates."""
        if start < parameters.time_min_data:
            start = parameters.time_min_data
        if end > parameters.time_max_data_proj:
            end = parameters.time_max_data_proj
        return start, end

    def __process_period(self):
        """Parse, order, extend and clip the period(s); flag out-of-bounds."""
        try:
            if not isinstance(self.__starttime, list):
                self.__starttime = pd.to_datetime(
                    self.__starttime, format="%Y-%m-%d")
                self.__endtime = pd.to_datetime(
                    self.__endtime, format="%Y-%m-%d")
                self.__starttime, self.__endtime = tools.put_in_order_dates(
                    self.__starttime, self.__endtime
                )
                # make the end date cover (almost) its whole day
                self.__endtime = self.__endtime + timedelta(hours=23.99)

                # handle out of bounds
                if self.__is_outside_data_range(self.__starttime, self.__endtime):
                    self.__time_out_of_bounds = True

                # if start time is before minimum time & end after max
                self.__starttime, self.__endtime = self.__clamp_to_data_range(
                    self.__starttime, self.__endtime
                )
            else:
                self.__comparison_time = True
                self.__starttime = [
                    pd.to_datetime(date, format="%Y-%m-%d")
                    for date in self.__starttime
                ]
                self.__endtime = [
                    pd.to_datetime(date, format="%Y-%m-%d")
                    for date in self.__endtime
                ]
                for i, start in enumerate(self.__starttime):
                    start, end = tools.put_in_order_dates(
                        start, self.__endtime[i])
                    self.__starttime[i], self.__endtime[i] = (
                        start,
                        end + timedelta(hours=23.99),
                    )

                # handle out of bounds: one bad interval flags the request
                if any(
                    self.__is_outside_data_range(start, end)
                    for start, end in zip(self.__starttime, self.__endtime)
                ):
                    self.__time_out_of_bounds = True

                # clip every interval to the available data
                for i, start in enumerate(self.__starttime):
                    self.__starttime[i], self.__endtime[i] = (
                        self.__clamp_to_data_range(start, self.__endtime[i])
                    )
        except pd.errors.OutOfBoundsDatetime:
            # date not representable as a pandas timestamp
            self.__time_out_of_bounds = True

    def __process_location(self):
        """Resolve the location(s) and flag those absent from the shapefiles."""
        if isinstance(self.__location, str):
            self.__location = tools.get_infos_from_location(self.__location)
            logger.debug(f"location info: {self.__location}")
            if self.__location["addresstype"] not in parameters.large_regions_type:
                if (
                    self.__location["location_name_orig"].replace(" - ", "-")
                    not in self.__shapefile_europe_cities.LAU_NAME.values
                ):
                    self.__out_of_bounds = True
                if self.__graph_type == "heatmap":
                    self.__heatmap_not_supported = True
            else:
                if (
                    self.__location["location_name"]
                    not in ShapefileReader.list_countries()
                ):
                    self.__out_of_bounds = True

        elif isinstance(self.__location, list):
            self.__comparison_loc = True
            self.__location = [
                tools.get_infos_from_location(loc) for loc in self.__location
            ]
            for location in self.__location:
                if location["addresstype"] not in parameters.large_regions_type:
                    found = (
                        location["location_name_orig"].replace(" - ", "-")
                        in self.__shapefile_europe_cities.LAU_NAME.values
                    )
                else:
                    found = (
                        location["location_name"]
                        in ShapefileReader.list_countries()
                    )
                if not found:
                    self.__list_not_found_loc.append(location["location_name"])
                    self.__out_of_bounds = True

    def __process_aggregation(self):
        """Translate the aggregation scale and adapt it to the graph type."""
        # ----- process agg scale -----
        if self.__aggreg_type == "year":
            self.__aggreg_type = "YS"
        elif self.__aggreg_type == "month":
            self.__aggreg_type = "MS"
        elif self.__aggreg_type == "week":
            self.__aggreg_type = "W"
        elif self.__aggreg_type == "day":
            self.__aggreg_type = "D"
        else:
            self.__aggreg_type = "raw"

        # Define aggreg for specific graphs
        if self.__aggreg_type == "raw" and self.__graph_type in [
                "bar chart", "box plot"]:
            self.__aggreg_type = parameters.defaut_time_aggreg_if_necessary

        if (
            self.__aggreg_type in ["raw", "W", "D"]
            and self.__graph_type == "warming stripes"
        ):
            self.__aggreg_type = "YS"

        if (
            self.__graph_type == 'line chart'
            and self.delta_time >= parameters.threshold_rainbow
        ):
            # Same aggregation with or without anomaly mode for now; can be
            # changed to integrate other aggregations.
            self.__aggreg_type = "MS"

        if self.__graph_type == 'heatmap':
            self.__aggreg_type = 'average over the period'

        if self.__graph_type == 'warming stripes' and self.__aggreg_type == 'YS':
            # yearly stripes: align the period(s) on calendar years
            if not isinstance(self.__starttime, list):
                self.__starttime = pd.to_datetime(
                    f"{self.__starttime.year}-01-01")
                self.__endtime = pd.to_datetime(
                    f"{self.__endtime.year}-12-31")
            else:
                for i, start in enumerate(self.__starttime):
                    end = self.__endtime[i]
                    self.__starttime[i] = pd.to_datetime(f"{start.year}-01-01")
                    self.__endtime[i] = pd.to_datetime(f"{end.year}-12-31")

    @property
    def params_after_processing(self):
        """
        Processed parameters as a list of dicts, one per compared element.

        A single request gives one dict; a location comparison gives one
        dict per location, a time comparison one per period, and a combined
        comparison pairs the i-th location with the i-th period.
        """
        def entry(start, end, loc):
            return {
                "starttime": start,
                "endtime": end,
                "location": loc,
                "climate_variable": self.var_name,
                "graph_type": self.graph_type,
                "aggreg_type": self.aggreg_type,
            }

        if self.comparison_loc and self.comparison_time:
            return [
                entry(self.starttime[i], self.endtime[i], loc)
                for i, loc in enumerate(self.location)
            ]
        if self.comparison_loc:
            return [
                entry(self.starttime, self.endtime, loc)
                for loc in self.location
            ]
        if self.comparison_time:
            return [
                entry(start, end, self.location)
                for start, end in zip(self.starttime, self.endtime)
            ]
        return [entry(self.starttime, self.endtime, self.location)]

    @property
    def context(self):
        return self.__graph_context

    @property
    def params(self):
        return self.__params

    @property
    def langage(self):
        return self.__langage

    @property
    def climate_variable(self):
        return self.__climate_variable

    @property
    def location(self):
        return self.__location

    @property
    def starttime(self):
        return self.__starttime

    @property
    def endtime(self):
        return self.__endtime

    @property
    def graph_type(self):
        return self.__graph_type

    @property
    def aggreg_type(self):
        return self.__aggreg_type

    @property
    def data(self):
        return self.__data

    @property
    def comparison_loc(self):
        return self.__comparison_loc

    @property
    def comparison_time(self):
        return self.__comparison_time

    @property
    def out_of_bounds(self):
        return self.__out_of_bounds

    @property
    def time_out_of_bounds(self):
        return self.__time_out_of_bounds

    @property
    def heatmap_not_supported(self):
        return self.__heatmap_not_supported

    @property
    def list_not_found_loc(self):
        return self.__list_not_found_loc

    @property
    def graph_context(self):
        return self.__graph_context

    @property
    def shapefile_europe(self):
        return self.__shapefile_europe

    @property
    def shapefile_europe_cities(self):
        return self.__shapefile_europe_cities

    @property
    def delta_time(self):
        """Length in days of the (first) period, or None if out of bounds."""
        if self.__time_out_of_bounds:
            return None
        if self.__comparison_time:
            return (self.__endtime[0] - self.__starttime[0]).days
        return (self.__endtime - self.__starttime).days

    @property
    def delta_time_multi(self):
        """Length in days of each compared element, or None if out of bounds."""
        if self.__time_out_of_bounds:
            return None
        if self.__comparison_time:
            return [
                (endtime - starttime).days
                for (endtime, starttime) in zip(self.__endtime, self.__starttime)
            ]
        return [(self.__endtime - self.__starttime).days] * self.NB_PARAM_ELM

    @property
    def NB_PARAM_ELM(self):
        """Number of entries in ``params_after_processing``."""
        return len(self.params_after_processing)

    @property
    def cross_tenses(self):
        """
        Whether ``parameters.time_max_data`` falls inside the period.

        Returns a single boolean, or one boolean per compared element for a
        time comparison.
        """
        if not self.comparison_time:
            return self.starttime <= parameters.time_max_data <= self.endtime
        return [
            self.starttime[i] <= parameters.time_max_data <= self.endtime[i]
            for i in range(self.NB_PARAM_ELM)
        ]