# Copyright 2024 Mews
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
import hashlib
import json
import os
import random
import time
import uuid
import jwt
from csv import writer
from datetime import datetime, timezone
from typing import TYPE_CHECKING
import base64
import argostranslate.translate as translator
import geopandas as gpd
import keycloak
import numpy as np
import pandas as pd
import pydeck as pdk
import streamlit as st
import streamlit.components.v1 as components
import unicodedata
import yaml
from geopy.geocoders import Nominatim
from keycloak import KeycloakOpenID
from loguru import logger
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import sent_tokenize, word_tokenize
from yaml.loader import SafeLoader
from urllib.parse import urlencode
import requests
from elasticsearch import Elasticsearch
from diva import parameters
if TYPE_CHECKING:
from diva.config import ModuleConfig
from diva.chat import ModuleChat
# Module-level WordNet lemmatizer, created once at import time and reused by lemmatize().
lemmatizer = WordNetLemmatizer()
[docs]
def send_user_event(event_type):
    """
    Sends a user activity event to the configured Elasticsearch data stream.

    The document carries a UTC timestamp with millisecond precision, the event type, and the user and
    session identifiers read from ``st.session_state.connected_user``. Sending is best effort: a failure
    while creating the client or indexing the document is logged (with traceback) and swallowed so that
    telemetry problems never interrupt the calling flow.

    Parameters
    ----------
    event_type:
        label of the event; callers in this module pass "connection" and "disconnection".

    Returns
    -------
    None
    """
    dt = datetime.now(timezone.utc)
    # ISO-8601 with milliseconds and a literal "Z" suffix (isoformat() would emit "+00:00" instead).
    iso_format = dt.strftime('%Y-%m-%dT%H:%M:%S.') + f"{int(dt.microsecond / 1000):03d}Z"
    connected_user = st.session_state.connected_user
    document = {
        "@timestamp": iso_format,
        "service_name": "DIVA",
        "event_timestamp": iso_format,
        "event_type": event_type,
        "user_id": connected_user.get('userID', ''),
        "session_id": connected_user.get('session_id', ''),
    }
    try:
        es = Elasticsearch(
            [parameters.elasticsearch_endpoint],
            headers={
                "Authorization": "ApiKey {}".format(parameters.elasticsearch_apikey),
                "Content-Type": "application/json",
            },
        )
        es.index(index=parameters.elasticsearch_ds_name, document=document)
    except Exception:
        # Deliberately broad: this is a best-effort boundary, but unlike a bare ``except`` it lets
        # KeyboardInterrupt / SystemExit propagate and records the traceback.
        logger.exception("Failed to send event to Elasticsearch")
    else:
        logger.info(f"Event sent to Elasticsearch: {event_type}")
[docs]
def find_best_match(
        text: str,
        list_: list | pd.Series | np.ndarray,
        verbose=False,
        extra=0,
        prefilter=True,
        remove_accents=True  # TODO remove_accents is also the name of a function in this module (shadowed here)
) -> tuple[str | None, int | None]:
    """
    This function finds the best match of a text in a list/array/series. It is meant to efficiently (time wise)
    and accurately find the best match between the retrieved location parameters (text) and the locations
    listed in the shapefiles (list_).
    The function assigns a score to every candidate element of list_, the highest score
    meaning the best match with the text. To save time, a filtering process is applied at an early stage to exclude
    the list/array/series (list_) elements for which the first letters do not match the first letters of the text.
    Run time with table of 86 000 rows < 0.01 s (figure reported by the original author).

    Parameters
    ----------
    text: str
        element for which the function searches the best match in list_
    list_: list | pd.Series | np.ndarray,
        where the best match is searched. Lists and arrays are converted to a pd.Series first.
    verbose: bool
        whether to print details. Mainly for debugging. Only has an effect when prefilter is False.
    extra: int
        any integer can be added to add a bonus or a malus to the score. If a malus is added, it makes more difficult
        to get a non-negative score, ensuring that the returned value is either None or very close to the text
        passed in argument
    prefilter: bool
        default to True. The prefilter excludes the list_ elements whose first two characters do not match the
        first two characters of one of the first two words of text.
    remove_accents: bool
        default to True. Removes the accents in the list_ elements.

    Returns
    -------
    tuple[str, int] or tuple[None, None]
        tuple[None, None] if the best score < 0 (or if no candidate remains after filtering).
        tuple[str, int] if the best score >= 0 with str the best match (as stored in list_) and int the index
        label of the best match in the series
    """
    best, idx = None, None
    if type(list_) is list:
        list_ = pd.Series(list_)
    elif isinstance(list_, np.ndarray):
        list_ = pd.Series(list_)
    # remove unwanted characters, and list words in elem
    text = clean_answer(text, ".,():;")
    text = text.lower()
    splits = text.split(" ")
    if isinstance(list_, pd.Series):
        if (
                prefilter
        ):  # speeds up the computation time for Series with many index, so True by default
            words = text.lower().split(" ")
            # first two characters of (at most) the first two words; a space pads one-letter words
            first_letters = [(w + " ")[0:2] for w in words[0: min(len(words), 2)]]
            # also accept the capitalized variants, since list_ is compared before being lowercased
            first_letters += [
                w.capitalize()
                for w in first_letters
                if w.capitalize() not in first_letters
            ]
            mask = list_.apply(lambda x: True if x[0:2] in first_letters else False)
            poss_matches = list_.loc[mask]
            poss_matches = poss_matches.apply(lambda x: x.lower())
        else:
            poss_matches = list_.apply(lambda x: x.lower())
        if remove_accents:
            poss_matches = poss_matches.apply(lambda x: clean_answer(x, characters=".,():;"))
        else:
            poss_matches = poss_matches.apply(
                lambda x: clean_answer(x, characters=".,():;", remove_accent=False)
            )
        # filter by length of text with +1 margin
        mini = min([int(len(text) * 0.75)] + [len(w) for w in text.split(" ")])
        maxi = int(len(text) * 1.15)
        # NOTE(review): both masks below are computed on the raw list_ (not on the cleaned poss_matches);
        # the upper bound is only enforced for short texts (< 8 characters).
        if len(text) < 8:
            maxi = int(len(text) * 1.15)
            mask = list_.apply(
                lambda x: (
                    True
                    if mini < max([len(w) for w in x.split(" ")]) <= maxi
                    else False
                )
            )
        else:
            mask = list_.apply(
                lambda x: True if mini < max([len(w) for w in x.split(" ")]) else False
            )
        poss_matches = poss_matches.loc[mask]
        # +2 / -2 for each word of text found / not found (substring test) in the candidate
        score = poss_matches.apply(lambda x: sum([2 if w in x else -2 for w in splits]))
        poss_matches = poss_matches.apply(lambda x: x.split(" "))
        # +1 / -1 for each word of the candidate found / not found (substring test) in text
        score2 = poss_matches.apply(lambda x: sum([1 if w in text else -1 for w in x]))
        score = score + score2 + extra
        if verbose and not prefilter:
            print(pd.concat([list_, score], axis=1))
        # an empty score series has a NaN max, which fails this test and leaves (None, None)
        if score.max() >= 0:
            idx = score.index[score.argmax()]
            best = list_.loc[idx]
    return best, idx
[docs]
def lemmatize(text: str | list[str]) -> str | list[str]:
    """
    Lemmatizes a text or a list of words with the module-level WordNet lemmatizer.

    A str input is first split into tokens with nltk word_tokenize, every token is lemmatized individually
    and the result is joined back with single spaces. Any other input is iterated as is and a list of
    lemmatized elements is returned.
    According to the original author, the first call takes ~2s (inner models loading) and capitalized words
    are left unchanged.

    Parameters
    ----------
    text: str | list[str]
        text (or list of words) to lemmatize

    Returns
    -------
    str | list[str]
        str if text is of type str, list[str] otherwise
    """
    is_single_string = type(text) is str
    tokens = word_tokenize(text) if is_single_string else text
    lemmas = [lemmatizer.lemmatize(token) for token in tokens]
    return " ".join(lemmas) if is_single_string else lemmas
[docs]
def no_truncated_sentence(text: str) -> str:
    """
    Removes an unterminated sentence at the end of a text. Uses nltk sent_tokenize to split the text in sentences.
    If there is a single (unterminated) sentence in the text, it is not removed.

    Parameters
    ----------
    text: str
        the text to verify

    Returns
    -------
    str
        the text minus its last sentence when that sentence is unterminated, the text holds more than one
        sentence and is longer than 5 characters; the unchanged text otherwise (including empty text)
    """
    tokenized_text = sent_tokenize(text, "english")
    if not tokenized_text:
        # nothing to inspect (e.g. empty text): avoid indexing an empty list
        return text
    closing_characters = (".", "!", "?", ")", "]", "\"", "'", "}")
    if tokenized_text[-1].endswith(closing_characters):
        return text
    # len(text) > 5 ensures that single text "float" scores generated by the LLM are not truncated.
    if len(tokenized_text) > 1 and len(text) > 5:
        # drop only the trailing unterminated sentence ([0:-2] also discarded the last complete one)
        text = " ".join(tokenized_text[:-1])
    return text
[docs]
def enumeration(list_: list, end="and") -> str:
    """
    Builds a textual enumeration out of a list of elements (typically strings).
    e.g. ['cats', 'dogs'] -> 'cats and dogs'
    e.g. ['cats', 'dogs', 'birds'] -> 'cats, dogs and birds'.
    Elements are separated by commas, except the last two which are separated by `end`.

    Parameters
    ----------
    list_: list
        the elements to enumerate
    end:
        default to 'and'. Word placed between the penultimate and the last element.

    Returns
    -------
    str:
        the enumeration; an empty string for an empty list, the element itself for a single-element list
    """
    if len(list_) == 0:
        return ""
    if len(list_) == 1:
        return list_[0]
    leading_part = ", ".join(list_[:-1])
    return leading_part + f" {end} {list_[-1]}"
[docs]
def remove_accents(text: str) -> str:
    """
    Returns the text stripped of its accents.

    The text is decomposed (NFKD) so that accents become separate combining marks, then every character
    that cannot be encoded in ASCII is dropped.
    """
    decomposed = unicodedata.normalize("NFKD", text)
    ascii_only = decomposed.encode("ascii", "ignore")
    return ascii_only.decode("utf-8")
[docs]
def last_day_of_month(month: int, year: int) -> int:
    """
    Returns the number of the last day of a month.

    February follows the Gregorian leap-year rule (divisible by 4, except centuries not divisible by 400).
    Any month that is neither February nor one of the 31-day months yields 30.
    """
    if month == 2:
        is_leap_year = year % 4 == 0 and not (year % 100 == 0 and year % 400 != 0)
        return 29 if is_leap_year else 28
    return 31 if month in (1, 3, 5, 7, 8, 10, 12) else 30
[docs]
def is_time(text: str = "") -> bool:
    """
    Takes a text as input and determines whether this text contains a time expression. Specifically, is_time
    searches for digits present in the text, or for months and seasons written as words in the text. If any of
    the two is found, the text is assumed to be a time expression and True is returned. This function was not
    created to be very accurate, but more with the goal of excluding probable time expressions from the list of
    retrieved parameters that should not be time (e.g. locations which should contain neither months, nor seasons,
    nor digits).

    Parameters
    ----------
    text: str
        text to assess

    Returns
    -------
    bool
        True if a time expression, False otherwise
    """
    months_seasons = {
        "january", "february", "march", "april", "may", "june", "july", "august", "september",
        "october", "november", "december", "jan", "feb", "mar", "apr", "jun", "jul", "aug", "sep", "oct",
        "nov", "dec", "spring", "summer", "autumn", "fall", "winter"
    }
    # any digit is enough to flag the text as a time expression
    if any(char.isdigit() for char in text):
        return True
    # otherwise compare each space-separated word (punctuation and accents removed, lowercased)
    words = clean_answer(text, ".!?,").lower().split(" ")
    return any(word in months_seasons for word in words)
[docs]
def set_max_new_tokens(prompt: str) -> int:
    """
    Estimates the max number of new tokens needed to answer a prompt.

    The estimate is 2 raised to the rounded cube root of the prompt length, clamped to [128, 1024]; the lower
    bound is raised to 256 when the prompt contains " and ".

    Returns
    -------
    int
        an integer between 128 and 1024
    """
    # round() without ndigits returns an int, so the power (and the result) stays an int
    estimate = 2 ** round(len(prompt) ** (1 / 3))
    lower_bound = 256 if " and " in prompt else 128
    return min(max(estimate, lower_bound), 1024)
[docs]
def str_from_datetime(date: datetime) -> str:
    """ Formats a datetime as a 'YYYY-MM-DD' string (format %Y-%m-%d). """
    day_format = "%Y-%m-%d"
    return datetime.strftime(date, day_format)
[docs]
def datetime_from_str(date: str) -> datetime:
    """ Parses a 'YYYY-MM-DD' string (format %Y-%m-%d) into a datetime. """
    day_format = "%Y-%m-%d"
    return datetime.strptime(date, day_format)
[docs]
def clean_answer(text: str, characters: str | list[str] | tuple, remove_accent=True) -> str:
"""
Removes characters from a text.
Parameters
----------
text: str | list[str]
text from which removing characters.
characters: iterable
an iterable object contains the characters (str) to remove
remove_accent:
default to True. Also replaces the accentuated characters in the text by non-accentuated characters
Returns
-------
str
The text cleared from the specified characters (and from accentuated characters if parameter set as True)
"""
if remove_accent:
text = remove_accents(text)
for char in characters:
text = text.replace(char, "")
return text
[docs]
def append_if_not_in(list_: list, elem):
    """ Appends elem to list_ (in place) unless it is already present, and returns that same list. """
    if elem in list_:
        return list_
    list_.append(elem)
    return list_
[docs]
def generate_map():
    """
    Displays an interactive Pydeck map in the Streamlit application.

    The map uses the "road" style and the following initial view:
        - Latitude: 47
        - Longitude: 12.5
        - Zoom level: 4
        - Pitch: 50

    Returns:
        None
    """
    mapstyle = "road"  # the original author lists 'light', 'dark', 'satellite', 'road' as options
    initial_view = pdk.ViewState(latitude=47, longitude=12.5, zoom=4, pitch=50)
    deck = pdk.Deck(map_style=f"{mapstyle}", initial_view_state=initial_view)
    st.pydeck_chart(deck)
[docs]
def load_data(dir: str, type_: str):
    """
    Loads the ';'-separated files data.csv and perf_all.csv from a directory.

    Parameters
    ----------
    dir: str
        directory holding both files
    type_: str
        value of the "type" column used to filter data.csv

    Returns
    -------
    tuple
        (full content of data.csv, rows of data.csv whose "type" equals type_, content of perf_all.csv)
    """
    data = pd.read_csv("/".join((dir, "data.csv")), sep=";")
    rows_of_type = data[data["type"] == type_]
    performances = pd.read_csv("/".join((dir, "perf_all.csv")), sep=";")
    return data, rows_of_type, performances
[docs]
def export_results(dir, now, data, df_perf, perf, type_):
    """
    Writes data and the updated performance table back to data.csv and perf_all.csv (';'-separated) in dir.

    The row of df_perf named "PERF <TYPE_ upper-cased>" is updated with the new results, or created when no
    such row exists.

    Parameters
    ----------
    dir:
        directory where both files are written
    now:
        value stored in the "date" column of the performance row
    data:
        DataFrame written to data.csv
    df_perf:
        performance table; the code relies on the columns "name", "result1", "result2" and "date"
        (and on exactly 4 columns when a new row has to be created)
    perf: str
        performance text; when it contains " = " it is split into (result1, result2), otherwise result1 is
        None and result2 is the whole text
    type_: str
        type label used to build the row name
    """
    data.to_csv(dir + "/data.csv", sep=";", index=False)
    # perfs as .csv
    # index labels of the existing row(s) for this type, if any
    idx = df_perf["name"].loc[lambda x: x == f"PERF {type_.upper()}"].index
    # NOTE(review): a perf text with more than one " = " would make this unpacking fail — confirm inputs
    res1, res2 = perf.split(" = ") if " = " in perf else (None, perf)
    if len(idx) == 0:
        # no row yet for this type: build a one-row frame with the same columns and append it
        columns = df_perf.columns
        new_row = pd.DataFrame(np.zeros([1, 4]), columns=columns)
        # cast to object before the assignment below; "result2" is left as created by np.zeros
        new_row[["name", "result1", "date"]] = new_row[["name", "result1", "date"]].astype("object")
        new_row.loc[:, :] = [f"PERF {type_.upper()}", res1, res2, now]
        df_perf = pd.concat([df_perf, new_row], axis=0)
    else:
        df_perf.loc[idx, ["result1", "result2", "date"]] = [res1, res2, now]
    df_perf.to_csv(dir + "/perf_all.csv", index=False, sep=";")
[docs]
def full_pipeline_config(
        module_config: 'ModuleConfig',
        module_chat: 'ModuleChat',
        user_prompt: str,
        clear: bool = True,
        config: bool = True
):
    """
    Runs pipeline to get graph parameters (config) from the user_prompt.

    The chat steps are called in a fixed order on module_chat: create_user_prompt, prompt_classification,
    is_prompt_in_scope, prompt_rephrasing, generate_text_answer. Nothing is returned; results are presumably
    kept on the two module objects (their classes are not visible here).

    Parameters
    ----------
    module_config: instance of ModuleConfig
    module_chat: instance of ModuleChat
    user_prompt: str
    clear: bool
        default to True. Whether to call module_config.clear_history() before processing the prompt (True)
        or not (False)
    config: bool
        default to True. Whether to search for the config parameters in the user prompt (True) or not (False)
    """
    if clear:
        module_config.clear_history()
    module_chat.create_user_prompt(user_prompt)
    module_chat.prompt_classification()
    module_chat.is_prompt_in_scope()
    module_chat.prompt_rephrasing()
    module_chat.generate_text_answer()
    if config:
        # extract the graph parameters only once the chat steps above have run
        module_config.prompt_to_config()
[docs]
def generate_session_id():
    """
    Generate a unique ID for the ongoing session.

    This function checks if a session ID already exists in the Streamlit session state.
    If it does not exist, a new one is generated as the first 20 hexadecimal characters of the SHA-256 hash
    of 24 random bytes from os.urandom, and stored in the Streamlit session state for later use.

    Returns:
    -------
    str
        The session ID, either newly created or retrieved from the session state.
    """
    if "session_id" not in st.session_state:
        random_seed = os.urandom(24)
        st.session_state["session_id"] = hashlib.sha256(random_seed).hexdigest()[0:20]
    # returned as documented (the value was previously only stored, and None returned)
    return st.session_state["session_id"]
[docs]
def get_description_of_graph(llm, context):
    """
    Asks a language model for a concise description of a graph.

    A fixed instruction (concise description, no introductory phrases) is concatenated with the context,
    without any separator, and passed to llm.generate.

    Parameters:
    -----------
    llm : object
        The language model instance; it must expose a generate(prompt) method.
    context : str
        The information about the graph, appended to the instruction.

    Returns:
    --------
    The value returned by llm.generate for the assembled prompt.
    """
    instruction = (
        "You are a chatbot for graph description, Provide a concise description of the graph using the "
        "following information, without introductory phrases:"
    )
    return llm.generate(instruction + context)
# ------------ feedback for graphs ----------
[docs]
def from_dict_to_str(params):
    """
    Renders the graph parameters as a one-line "Visualisation: key = value, ..." text.

    Parameters:
    -----------
    params : dict
        must contain the keys 'starttime', 'endtime', 'location', 'climate_variable', 'graph_type'
        and 'aggreg_type'.

    Returns:
    --------
    str
        The parameters listed in the order above.
    """
    ordered_keys = (
        "starttime",
        "endtime",
        "location",
        "climate_variable",
        "graph_type",
        "aggreg_type",
    )
    rendered_pairs = [f"{key} = {params[key]}" for key in ordered_keys]
    return "Visualisation: " + ", ".join(rendered_pairs)
[docs]
def drop_duplicated_columns(data):
    """
    Removes the columns of a DataFrame that duplicate the content of an earlier column.

    After the removal, the first remaining column whose name contains "x_" (if any) is renamed "x".

    Parameters:
    -----------
    data : pd.DataFrame
        The table to clean.

    Returns:
    --------
    pd.DataFrame
        A new table without duplicated columns.
    """
    # duplicates are detected on rows, hence the double transposition
    deduplicated = data.T.drop_duplicates().T
    x_like_columns = [name for name in deduplicated.columns if "x_" in name]
    if not x_like_columns:
        return deduplicated
    return deduplicated.rename(columns={x_like_columns[0]: "x"})
[docs]
def add_feedback(prompt, answer, feedback):
    """
    Stores a user feedback in the log file.

    The ';'-separated log file (parameters.path_logs) is read, the "Feedback" column is set on every row whose
    "SessionID" is the current session ID, "OriginalUserPrompt" equals prompt and "LLM_Response" equals answer,
    and the file is written back. If the log file does not exist, a Streamlit warning is displayed instead.

    Parameters:
    -----------
    prompt :
        The original user prompt the feedback refers to.
    answer :
        The LLM response the feedback refers to.
    feedback :
        The value to store in the "Feedback" column.

    Returns:
    --------
    None
    """
    if not os.path.exists(parameters.path_logs):
        st.warning("log file not found!")
        return
    logs = pd.read_csv(parameters.path_logs, sep=";")
    matching_rows = (
        (logs["SessionID"] == st.session_state.session_id)
        & (logs["OriginalUserPrompt"] == prompt)
        & (logs["LLM_Response"] == answer)
    )
    logs.loc[matching_rows, "Feedback"] = feedback
    logs.to_csv(parameters.path_logs, sep=";", index=False)
# ------------ feedback for all msg ----------
# ------------ update logs ----------
[docs]
def update_logs(UserPrompt, RephrasedPrompt, LLM_Response):  # TODO naming convention
    """
    Updates the log file with one message exchange: user prompt, rephrased prompt and LLM response.

    The row also holds the current timestamp, the session ID and an empty feedback. If the log file
    (parameters.path_logs) already exists the row is appended to it, otherwise the file is created with
    a header line. The file is ';'-separated.

    Parameters:
    -----------
    UserPrompt : str
        The original prompt from the user.
    RephrasedPrompt : str
        The rephrased version of the user's prompt.
    LLM_Response : str
        The response generated by the LLM.

    Returns:
    --------
    None
    """
    columns = [
        "Timestamp",
        "SessionID",
        "OriginalUserPrompt",
        "RephrasedPrompt",
        "LLM_Response",
        "Feedback",
    ]
    logs_infos = [
        pd.Timestamp.now(),
        st.session_state.session_id,
        UserPrompt,
        RephrasedPrompt,
        LLM_Response,
        "",
    ]
    if os.path.exists(parameters.path_logs):
        # newline="" is required by the csv module (otherwise rows get "\r\r\n" on Windows and embedded
        # newlines are mangled); utf-8 matches the default encoding pandas uses for this file.
        with open(parameters.path_logs, "a", newline="", encoding="utf-8") as f_object:
            writer(f_object, delimiter=";").writerow(logs_infos)
    else:
        df = pd.DataFrame([logs_infos], columns=columns)
        df.to_csv(parameters.path_logs, index=False, sep=";")
[docs]
def str_to_list(text):
    """
    Splits a ", "-separated text into a list of at most 10 elements.

    The text is returned unchanged when it is "Unknown" or when it holds a single element.
    """
    if text == "Unknown":
        return text
    pieces = text.split(", ")
    return pieces[:10] if len(pieces) > 1 else text
[docs]
def convert_to_dict_config(config):
    """
    Builds a dictionary of graph parameters out of a configuration object.

    start_time, end_time and location go through str_to_list (so each may become a list); the other
    attributes are copied as they are.

    Parameters:
    -----------
    config : object
        Object exposing the attributes start_time, end_time, location, climate_variable, graph_type,
        aggregation_type and anomaly.

    Returns:
    --------
    dict
        Keys: 'starttime', 'endtime', 'location', 'climate_variable', 'graph_type', 'aggreg_type', 'anomaly'.
    """
    return dict(
        starttime=str_to_list(config.start_time),
        endtime=str_to_list(config.end_time),
        location=str_to_list(config.location),
        climate_variable=config.climate_variable,
        graph_type=config.graph_type,
        aggreg_type=config.aggregation_type,
        anomaly=config.anomaly,
    )
[docs]
def chech_if_empty_prompt(text):
    """ Returns True when the text is empty or made of space characters only, False otherwise. """
    without_leading_spaces = text.lstrip(" ")
    return without_leading_spaces == ""
[docs]
def write_like_chatGPT(text):
    """
    Simulates typing of the provided text, one word at a time.

    The text is cut at every space and after every newline; each piece is yielded with a trailing space,
    followed by a 0.05 s pause, which gives a typing effect similar to that of ChatGPT. Newlines stay attached
    to the end of the piece that precedes them.
    The text is split directly (no placeholder character), so a text containing "§" is rendered untouched.

    Parameters:
    -----------
    text : str
        The text to be "typed out".

    Yields:
    -------
    str
        Pieces of the text with a space appended, one at a time.
    """
    for chunk in text.split(" "):
        # every line break inside the chunk ends a piece; the remainder is the last piece
        *finished_lines, remainder = chunk.split("\n")
        for line in finished_lines:
            yield line + "\n "
            time.sleep(0.05)
        yield remainder + " "
        time.sleep(0.05)
# NOTE(review): GetLinkQuestionForMissing is neither defined nor imported in the visible part of this
# file — confirm where it comes from before relying on this module-level instance.
get_link_question_for_missing = GetLinkQuestionForMissing()
[docs]
def display_config(dict_params) -> str:
    """
    Generates a descriptive text about the graph configuration to be displayed before the graph.

    The description includes the climate variable, the aggregation, the start and end times, the location(s)
    and additional time ranges when several are provided. One sentence is picked at random among ten
    equivalent wordings.

    Parameters:
    -----------
    dict_params : dict
        A dictionary containing the keys 'climate_variable', 'starttime', 'endtime', 'location' and
        'aggreg_type'. 'starttime' / 'endtime' are either two str or two lists of the same length;
        'location' is a str or a list of str.

    Returns:
    --------
    str
        A randomly selected descriptive sentence about the graph configuration.
    """
    clim_var = dict_params["climate_variable"]
    starttime = dict_params["starttime"]
    endtime = dict_params["endtime"]
    location = dict_params["location"]
    aggreg_type = dict_params["aggreg_type"]
    second_time_indication = ""

    if isinstance(location, str):
        location = add_zipcode_to_city_name(location)
    if isinstance(location, list):
        # join() also copes with an empty list, which used to raise an IndexError
        location = ", ".join(add_zipcode_to_city_name(loc) for loc in location)

    if isinstance(starttime, list) and isinstance(endtime, list):
        starttime_all = starttime
        endtime_all = endtime
        # the first period fills the main sentence, the following ones are appended after a comma
        starttime = starttime_all[0]
        endtime = endtime_all[0] + ", "
        for i in range(1, len(starttime_all) - 1):
            second_time_indication += starttime_all[i] + " to " + endtime_all[i] + ", "
        second_time_indication += starttime_all[-1] + " to " + endtime_all[-1]

    aggregation_labels = (
        (("raw",), 'raw data'),
        (("day", "D"), 'daily data'),
        (("week", "W"), 'weekly data'),
        (("month", "MS"), 'monthly data'),
        (("year", "YS"), 'yearly data'),
        (("average over the period",), 'average over the period'),
    )
    # unknown aggregation types are shown as they are instead of leaving `aggregation` undefined
    aggregation = next(
        (label for codes, label in aggregation_labels if aggreg_type in codes),
        str(aggreg_type),
    )

    messages = [
        f"The graph below shows the {clim_var} ({aggregation}) from {starttime} to {endtime}{second_time_indication} in {location}. This graph is interactive: you can zoom, filter, and explore the data.",
        f"Below is a visualization of {clim_var} ({aggregation}) recorded between {starttime} and {endtime}{second_time_indication} in {location}. Feel free to interact with the graph to uncover more details.",
        f"Check out the graph below displaying {clim_var} ({aggregation}) in {location} from {starttime} to {endtime}{second_time_indication}. You can zoom in and out, filter the data, and more!",
        f"Here's a graph showing {clim_var} ({aggregation}) trends from {starttime} to {endtime}{second_time_indication} in {location}. It's interactive, so you can explore the data more deeply.",
        f"Observe the {clim_var} ({aggregation}) data in {location} spanning from {starttime} to {endtime}{second_time_indication} in the graph below. The graph is interactive for your convenience.",
        f"Presented below is an interactive graph of {clim_var} ({aggregation}) for {location}, covering the period from {starttime} to {endtime}{second_time_indication}. Zoom and filter to your liking.",
        f"The following graph illustrates {clim_var} ({aggregation}) trends in {location} between {starttime} and {endtime}{second_time_indication}. Feel free to interact with it for more insights.",
        f"Below you'll find an interactive chart showing {clim_var} ({aggregation}) in {location} from {starttime} to {endtime}{second_time_indication}. You can zoom, filter, and examine the data closely.",
        f"Explore the {clim_var} ({aggregation}) data from {starttime} to {endtime}{second_time_indication} in {location} with the graph below. It's designed to be interactive for a better analysis experience.",
        f"The graph below provides an interactive look at {clim_var} ({aggregation}) in {location} over the period from {starttime} to {endtime}{second_time_indication}. Discover patterns by zooming and filtering the data.",
    ]
    return random.choice(messages)
[docs]
def find_country(display_name):
    """
    Finds the English name of the (European) country mentioned in a ", "-separated display name.

    Every part of the display name is compared with the local country names known by this function
    (substring test). When several parts match, the last match wins. An empty string is returned when
    no country is recognised.

    Parameters:
    -----------
    display_name : str
        A ", "-separated description of a place.

    Returns:
    --------
    str
        The English name of the country, or "" if none is found.
    """
    country_dict = {
        "Andorra": "Andorra",
        "Беларусь": "Belarus",
        "België": "Belgium",
        "Belgique": "Belgium",
        "Belgien": "Belgium",
        "Bosna i Hercegovina": "Bosnia and Herzegovina",
        "Босна и Херцеговина": "Bosnia and Herzegovina",
        "България": "Bulgaria",
        "Hrvatska": "Croatia",
        "Kıbrıs": "Cyprus",
        "Κύπρος": "Cyprus",
        "Србија": "Serbia",
        "Северна Македонија": "North Macedonia",
        "Česko": "Czech Republic",
        "Civitas Vaticana": "Vatican City",
        "Città del Vaticano": "Vatican City",
        "Danmark": "Denmark",
        "Deutschland": "Germany",
        "Eesti": "Estonia",
        "Éire": "Ireland",
        "España": "Spain",
        "Finland": "Finland",
        "France": "France",
        "Ελλάς": "Greece",
        "Italia": "Italy",
        # English name, like every other value of this mapping (was "Island")
        "Ísland": "Iceland",
        "Ireland": "Ireland",
        "Latvija": "Latvia",
        "Liechtenstein": "Liechtenstein",
        "Lietuva": "Lithuania",
        "Magyarország": "Hungary",
        "Malta": "Malta",
        "Moldova": "Moldova",
        "Monaco": "Monaco",
        "Nederland": "Netherlands",
        "Norge": "Norway",
        "Österreich": "Austria",
        "Polska": "Poland",
        "Portugal": "Portugal",
        "România": "Romania",
        "Россия": "Russia",
        "Srbija": "Serbia",
        "San Marino": "San Marino",
        "Schweiz": "Switzerland",
        "Slovensko": "Slovakia",
        "Slovenija": "Slovenia",
        "Suisse": "Switzerland",
        "Sverige": "Sweden",
        "Svizzera": "Switzerland",
        "Svizra": "Switzerland",
        "Shqipëria": "Albania",
        "Suomi": "Finland",
        "Türkiye": "Turkey",
        "Україна": "Ukraine",
        "United Kingdom": "United Kingdom",
    }
    country = ""
    for info in display_name.split(", "):
        for local_name, english_name in country_dict.items():
            # no break on purpose: later parts of the display name override earlier matches
            if local_name in info:
                country = english_name
    return country
[docs]
@st.cache_data
def get_infos_from_location(location):
    """
    Retrieves detailed information about a location, including latitude, longitude, and bounding box.

    This function uses the Nominatim geocoding service to obtain geographic details about a given location.
    "Athens" is queried as "Athens Greece". The result is cached by Streamlit (st.cache_data).

    Parameters:
    -----------
    location : str
        The name of the location for which information is to be retrieved.

    Returns:
    --------
    dict or None
        None if the location cannot be found, otherwise a dictionary with the keys:
        - "location_name_orig": the name returned by the geocoder,
        - "location_name": the name used in the query,
        - "addresstype": the address type returned by the geocoder,
        - "lat_lont_center": {"lat", "lon"} of the center, rounded to 1 decimal,
        - "boundingbox": {"lat_min", "lat_max", "lon_min", "lon_max"} as floats.
    """
    location_name = location
    if location_name == "Athens":
        location_name = "Athens Greece"
    logger.info(f"geocoding: {location_name}")
    geolocator = Nominatim(user_agent="myapplication")
    # query with location_name so that the disambiguation above is actually applied
    geocoded = geolocator.geocode(location_name, timeout=100)
    logger.info(f"location: {geocoded}")
    if geocoded is None:
        return None
    raw = geocoded.raw
    bounding_values = [float(x) for x in raw["boundingbox"]]
    # the country (see find_country) is currently not part of the returned information
    return {
        "location_name_orig": raw["name"],
        "location_name": location_name,
        "addresstype": raw["addresstype"],
        "lat_lont_center": dict(zip(["lat", "lon"], np.round(geocoded[1], 1))),
        "boundingbox": dict(zip(["lat_min", "lat_max", "lon_min", "lon_max"], bounding_values)),
    }
[docs]
def put_in_order_dates(start, end):
    """
    Returns two dates in ascending order.

    The dates are swapped only when end is strictly earlier than start; in every other case (including
    equal dates) they are returned in the order they were provided.

    Parameters:
    -----------
    start : datetime or str
        The start date.
    end : datetime or str
        The end date.

    Returns:
    --------
    tuple
        The two dates, the earlier one first.
    """
    if start < end:
        return start, end
    return (end, start) if end < start else (start, end)
[docs]
def langage_to_iso(language):
    """
    Converts a language name to its ISO 639-1 code.

    Supported names: English, French, German, Spanish, Italian and Dutch (exact, case-sensitive match).

    Parameters:
    -----------
    language : str
        The name of the language.

    Returns:
    --------
    str
        The ISO 639-1 code of the language. Returns None if the language is not recognized.
    """
    iso_codes = {
        "English": "en",
        "French": "fr",
        "German": "de",
        "Spanish": "es",
        "Italian": "it",
        "Dutch": "nl",
    }
    return iso_codes.get(language)
[docs]
def init_state():
    """
    Initializes the tab state of the Streamlit application.

    When "current_tab" is not yet in the session state, it is set to parameters.tab1 and "index_tab" is
    set to 0. Otherwise nothing is changed.

    Returns:
    --------
    None
    """
    if "current_tab" in st.session_state:
        return
    st.session_state["current_tab"] = parameters.tab1
    st.session_state["index_tab"] = 0
# ---------------------------------
[docs]
def generate_connection_page(auth_url):
    """
    Displays the log-in invitation: a title and a link to auth_url opened in the same tab.

    Parameters:
    -----------
    auth_url : str
        The URL of the authentication page, inserted as is in the link.
    """
    st.write("### Log in to your DESP account to access the application")
    st.markdown(
        f'<a href="{auth_url}" target="_self">Click here to log in</a>',
        unsafe_allow_html=True,
    )
[docs]
def check_token_validity(keycloak_openid):
    """
    Asks Keycloak whether the access token stored in the session state is active.

    Parameters:
    -----------
    keycloak_openid :
        The Keycloak OpenID client used for the token introspection.

    Returns:
    --------
    The "active" field of the introspection result; an empty token is introspected when the session
    holds no access token.
    """
    current_token = st.session_state.get("access_token", '')
    introspection = keycloak_openid.introspect(current_token)
    return introspection["active"]
[docs]
def verify_authentification_keycloak():
    """
    Verifies user authentication with Keycloak.

    This function drives the authentication process using Keycloak's OpenID Connect protocol.
    While no access token is stored in the session state, the log-in link is displayed; once Keycloak
    redirects back with a `code` query parameter, the code is exchanged for tokens, the access token is
    decoded with the realm public key (RS256), the tokens and the connected user are stored in the session
    state, a "connection" event is sent and the script is rerun. When an access token is present, it is
    checked with check_token_validity; an inactive token clears the session state and reruns the script.

    Parameters:
    -----------
    None

    Returns:
    --------
    tuple: (bool, KeycloakOpenID)
        The bool is True if an active access token is stored in the session, False if the user is not (yet)
        authenticated or if the code exchange failed. The second element is the Keycloak OpenID client.
    """
    IAM_URL = parameters.server_url
    REALM = parameters.realm_name
    CLIENT_ID = parameters.client_id
    SERVICE_URL = parameters.redirect_uri
    CLIENT_SECRET = parameters.client_secret
    keycloak_openid = KeycloakOpenID(server_url=IAM_URL,
                                     client_id=CLIENT_ID,
                                     realm_name=REALM,
                                     client_secret_key=CLIENT_SECRET)
    # NOTE(review): bare except — any failure (even KeyboardInterrupt) silently falls back to the server URL
    try:
        auth_url = keycloak_openid.auth_url(
            redirect_uri=SERVICE_URL,
            scope="openid",
            state=str(uuid.uuid4()))
    except:
        auth_url = IAM_URL
    if "access_token" not in st.session_state:
        generate_connection_page(auth_url)
        # `code` is the authorization code appended to the redirect URL
        code = st.query_params.get('code')
        if not code:
            if "auth_initiated" not in st.session_state:
                st.session_state["auth_initiated"] = True
            return False, keycloak_openid
        if code:
            try:
                access_token_keycloak = keycloak_openid.token(
                    grant_type='authorization_code',
                    code=code,
                    redirect_uri=SERVICE_URL
                )
                # wrap the public key in PEM markers before handing it to jwt.decode
                public_key = keycloak_openid.public_key()
                formatted_key = "-----BEGIN PUBLIC KEY-----\n" + public_key + "\n-----END PUBLIC KEY-----"
                decoded_access_token = jwt.decode(
                    access_token_keycloak["access_token"],
                    key=formatted_key,
                    algorithms=["RS256"],
                )
                st.session_state["access_token"] = access_token_keycloak['access_token']
                st.session_state["refresh_token"] = access_token_keycloak['refresh_token']
                st.session_state["url_code"] = code
                userinfo = keycloak_openid.userinfo(access_token_keycloak['access_token'])
                st.session_state.connected_user = {"user": userinfo["preferred_username"],
                                                   "userID": userinfo["sub"],
                                                   "session_id": decoded_access_token['sid'],
                                                   "info_url": "http://diva.destine.eurobios.com:8000/index.html"}
                # users whose token lists the "DPAD_Direct_Access" group are flagged as "granted"
                if 'access_group' in decoded_access_token and "DPAD_Direct_Access" in decoded_access_token["access_group"]:
                    st.session_state.connected_user['user_type'] = "granted"
                else:
                    st.session_state.connected_user['user_type'] = "normal"
                send_user_event("connection")
            except keycloak.exceptions.KeycloakPostError:
                # NOTE(review): only Keycloak POST errors are handled; jwt decoding errors propagate
                return False, keycloak_openid
            if "auth_initiated" in st.session_state :
                del st.session_state["auth_initiated"]
            st.rerun()
    token_is_active = check_token_validity(keycloak_openid)
    connection_authorized = "access_token" in st.session_state and token_is_active
    if not connection_authorized:
        st.session_state.clear()
        st.rerun()
    return connection_authorized, keycloak_openid
[docs]
def log_out_keycloak(keycloak_openid):
    """
    Logs the connected user out of Keycloak and records the disconnection.

    The refresh token stored in the Streamlit session state is handed to the
    Keycloak client, then a "disconnection" user event is sent.

    Parameters:
    -----------
    keycloak_openid: Client object exposing a `logout(refresh_token)` method.

    Raises:
    -------
    KeyError: If no "refresh_token" entry exists in the session state.
    """
    refresh_token = st.session_state["refresh_token"]
    keycloak_openid.logout(refresh_token)
    # The event is emitted only after the logout call has returned.
    send_user_event("disconnection")
[docs]
def generate_waiting_for_graph_msg():
    """
    Picks a short "please wait" sentence shown while a graph is being built.

    One sentence is drawn at random from a fixed pool of ten equivalent
    wordings, so that repeated graph requests do not always display the same text.

    Returns:
    --------
    str: One of the ten predefined waiting messages.
    """
    waiting_messages = (
        "The graph is currently being generated...",
        "Creating the graph, please wait...",
        "Building the graph, hang tight...",
        "The graph is being constructed...",
        "Graph generation in progress...",
        "Compiling the graph, just a moment...",
        "Preparing the graph, please stand by...",
        "The graph is on its way...",
        "Assembling the graph, one moment please...",
        "The graph is being created, hold on...",
    )
    return random.choice(waiting_messages)
[docs]
def get_10_colors():
    """
    Returns the ten-colour palette as hexadecimal colour strings.

    Returns:
    --------
    list: Ten "#rrggbb" strings, in a fixed order. A new list is built on every call.
    """
    named_palette = (
        ("blue", "#636efa"),
        ("red", "#ef553b"),
        ("green", "#00cc96"),
        ("purple", "#ab63fa"),
        ("orange", "#ffa15a"),
        ("cyan", "#19d3f3"),
        ("pink", "#ff6692"),
        ("light green", "#b6e880"),
        ("magenta", "#ff97ff"),
        ("yellow", "#feca57"),
    )
    # The names only document the palette; callers receive the hex codes.
    return [hex_code for _name, hex_code in named_palette]
[docs]
def get_10_colors_gray():
    """
    Returns the darkened variant of the ten-colour palette.

    Returns:
    --------
    list: Ten "#rrggbb" strings, in a fixed order. A new list is built on every call.
    """
    darkened_palette = (
        ("darkened blue", "#484e96"),
        ("darkened red", "#8c3a2c"),
        ("darkened green", "#009977"),
        ("darkened purple", "#7d48a8"),
        ("darkened orange", "#b36b3d"),
        ("darkened cyan", "#149aa9"),
        ("darkened pink", "#994d69"),
        ("darkened light green", "#869f58"),
        ("darkened magenta", "#b565b5"),
        ("darkened yellow", "#b68240"),
    )
    # The names only document the palette; callers receive the hex codes.
    return [hex_code for _name, hex_code in darkened_palette]
[docs]
def translate_from_en(text, source_lang):
    """
    Translates an English text into another language.

    Parameters:
    -----------
    text (str): The English text to translate.
    source_lang (str): Language code of the translation output. Despite its name, it is passed
        to the translator after "en", i.e. it is the language translated *to*.
        If it is "en", the text is returned as is.

    Returns:
    --------
    str: The translated text, or the original text when the language is "en".
    """
    if source_lang == "en":
        return text
    return translator.translate(text, "en", source_lang)
[docs]
def translate_to_en(text, source_lang):
    """
    Translates a given text to English, with additional handling for graph type terminology.

    If the source language is not English, a JSON mapping file is loaded. When it has an entry
    for the source language, every graph type term of that entry found in the text is replaced
    by its English equivalent wrapped in double quotes, before the whole text is translated.
    If the source language is already English, the text is returned without modification and
    the mapping file is not read.

    Parameters:
    -----------
    text (str): The text to be translated to English.
    source_lang (str): The language of the input text. If it is "en", the text is returned as is.

    Returns:
    --------
    str: The translated text in English, with graph type terms mapped if applicable.
    """
    if source_lang == "en":
        return text

    # The mapping holds non-English terms, so it must not be decoded with the
    # platform's locale encoding; JSON files are UTF-8.
    with open(parameters.graph_type_langage_mapping_to_en, encoding="utf-8") as f:
        graph_types_mapping = json.load(f)

    if source_lang in graph_types_mapping:
        for term, english_term in graph_types_mapping[source_lang].items():
            # Detection is case-insensitive, but only the term as written in the
            # mapping and its capitalized form are substituted.
            if term in text.lower():
                quoted_term = '"' + english_term + '"'
                text = text.replace(term, quoted_term)
                text = text.replace(term.capitalize(), quoted_term)
    return translator.translate(text, source_lang, "en")
[docs]
@st.cache_data
def load_shapefile_cities():
    """
    Reads the European cities shapefile and the world CSV file.

    The shapefile location is built from `parameters.path_data` and
    `parameters.shapefile_europe_cities`; the CSV is read from the relative path
    "../data/geo/world.csv". The function is wrapped with `st.cache_data`.

    Returns:
    --------
    tuple: A tuple containing two elements:
    - The object returned by `gpd.read_file` for the cities shapefile.
    - DataFrame: The content of the world CSV file.
    """
    cities_path = parameters.path_data + parameters.shapefile_europe_cities
    cities = gpd.read_file(cities_path)
    countries = pd.read_csv("../data/geo/world.csv")
    return cities, countries
[docs]
def get_city_zipcode(city):
    """
    Retrieves the country name and LAU ID (Local Administrative Unit) for a given city.

    The city is searched by exact match on the `LAU_NAME` column of the cities shapefile.
    The first matching row gives the country code (`CNTR_CODE`) and the LAU ID (`LAU_ID`).
    The lower-cased country code is then looked up in the `alpha2` column of the world table
    to obtain the country name.

    Parameters:
    -----------
    city (str): The name of the city for which the country and LAU ID are to be retrieved.

    Returns:
    --------
    tuple: A tuple containing:
    - str or None: The name of the country. None if the city is not found, or if its country
      code has no match in the world table.
    - str or None: The LAU ID of the city. None if the city is not found.
    """
    shapefile, world = load_shapefile_cities()
    matching_cities = shapefile[shapefile.LAU_NAME == city]
    if matching_cities.shape[0] == 0:
        return None, None

    cntr_code = matching_cities["CNTR_CODE"].values[0]
    lau_id = matching_cities["LAU_ID"].values[0]

    matching_countries = world[world['alpha2'] == cntr_code.lower()]
    if matching_countries.shape[0] == 0:
        # The shapefile code may not exist in the world table (the two sources
        # do not necessarily share one code list); indexing an empty selection
        # would raise IndexError, so report an unknown country instead.
        return None, lau_id
    return matching_countries['name'].values[0], lau_id
[docs]
def add_zipcode_to_city_name(location):
    """
    Appends the country name to a city name.

    Detailed information about the location is retrieved first, and its
    `location_name_orig` entry is used to look up the country of the city. If the location is
    not in the list of countries and a country was found, the result has the format
    `city (country)`. Otherwise the location is returned unchanged.

    Parameters:
    -----------
    location (str): The name of the location (city or country).

    Returns:
    --------
    str: The location name, followed by its country in parentheses when the location is a
    city whose country is known.
    """
    loc_info = get_infos_from_location(location)
    cntr, _lau_id = get_city_zipcode(loc_info['location_name_orig'])
    # An unknown city yields no country; skip the suffix rather than
    # producing the literal text "(None)".
    if location not in parameters.list_countries and cntr is not None:
        location += f" ({cntr})"
    return location
[docs]
def update_logging_history(logging_history, logging, prompt):
    """
    Prepends the log of the latest prompt to the HTML logging history.

    The new entry is the prompt in bold between two horizontal rules, followed by its log.
    When a previous history exists, it is appended after two line breaks, so the most recent
    entry always comes first.

    Parameters:
    -----------
    logging_history (str): The history built so far; '' when there is none.
    logging (str): The log text of the latest prompt.
    prompt (str): The prompt the log belongs to.

    Returns:
    --------
    str: The updated history.
    """
    # NOTE(review): prompt is inserted into the markup as is — confirm it is
    # escaped upstream if this string is rendered as HTML.
    fragments = ['<hr>', "<b>", prompt, "</b>", '<br />', '<hr>', logging]
    if logging_history != '':
        fragments.extend(['<br /><br />', logging_history])
    return ''.join(fragments)
[docs]
def interpolate_color(start_color, end_color, factor: float):
    """
    Linearly interpolates between two colours, component by component.

    Parameters:
    -----------
    start_color: Sequence of numeric components of the first colour.
    end_color: Sequence of numeric components of the second colour.
    factor (float): Position between the two colours; 0 gives the start colour, 1 the end colour.

    Returns:
    --------
    tuple: The interpolated components, each truncated with `int`. Its length is the length
    of the shorter input.
    """
    blended = []
    for low, high in zip(start_color, end_color):
        blended.append(int(low + (high - low) * factor))
    return tuple(blended)
[docs]
def get_code_of_function(function_name):
    """
    Returns the source code of a DIVA function or class, wrapped in a short message.

    The name is looked up in a fixed table that maps names (including double-underscore and
    `*_strategy` aliases) to objects of the `diva` package. The source is read with
    `inspect.getsource` and placed in a fenced Python block between an introduction sentence
    and a closing remark.

    Parameters:
    -----------
    function_name (str): A key of the lookup table, or the sentinel 'GET_ONLY_FUNCTION_NAMES'.

    Returns:
    --------
    list or str: With the sentinel, the list of all table keys in table order. Otherwise a
    string starting with the '[IS_GEN_FUN_CODE]' marker that contains the introduction, the
    source code and the closing remark.

    Raises:
    -------
    KeyError: If `function_name` is neither the sentinel nor a key of the lookup table.
    """
    # Imported here rather than at module level, as in the original code.
    from diva import data, graphs, config, chat, llm
    import inspect

    intro_templates = (
        "Below is the Python implementation of $$$$:",
        "Here’s the Python code for $$$$:",
        "The Python code for $$$$ is provided below:",
        "This is the Python script for $$$$:",
        "Here’s how $$$$ is implemented in Python:",
        "The following is the Python code for $$$$:",
        "Below is the Python snippet for $$$$:",
        "This is the Python code representing $$$$:",
        "Here’s the Python example for $$$$:",
        "The Python implementation of $$$$ is shown below:",
    )

    # Short aliases for the classes referenced repeatedly in the table.
    collection = data.dataset.DataCollection
    shape_reader = data.shapefile_masker.ShapefileReader
    graph_texts = graphs.graph_generator.GraphTexts
    graph_base = graphs.graph_generator.IGraphGenerator
    plotly_service = graphs.service_graph_generation.ServiceGeneratePlotlyGraph
    module_config = config.ModuleConfig
    asks_missings = config.services.AsksMissings
    creation = config.services.ConfigCreation
    completion = config.services.CompletionWithLast
    module_chat = chat.ModuleChat
    classification = chat.services.PromptClassification
    scopes = chat.services.ChatbotScopes
    memory = chat.services.IsMemoryNeeded
    rephrasing = chat.services.PromptRephrasing
    answer = chat.services.GenChatbotAnswer
    evals = llm.eval_tasks

    # Key order matters: it is the order of the list returned for the sentinel.
    lookup = {
        # data.dataset
        "get_data": collection.get_data,
        "sample_time": collection.sample_time,
        "apply_masks": collection.apply_masks,
        "spatial_aggregation": collection.spatial_aggregation,
        "temporal_aggregation": collection.temporal_aggregation,
        "aggregate": collection.aggregate,
        "get_values": collection.get_values,
        "get_mean": collection.get_mean,
        "get_max": collection.get_max,
        "get_min": collection.get_min,
        "get_groupe_values": collection.get_groupe_values,
        # data.shapefile_masker
        "read_city": shape_reader.read_city,
        "read_country": shape_reader.read_country,
        # graphs.graph_generator
        "textual_content": graph_texts.textual_content,
        "translate": graph_texts.translate,
        "initialise_data_collection": graph_base.initialise_data_collection,
        "process_params": graph_base.process_params,
        # graphs.service_graph_generation
        "generate": plotly_service.generate,
        "generate_lineplot": plotly_service.generate_lineplot,
        "generate_lineplot_mean": plotly_service.generate_lineplot_mean,
        "generate_rainbow": plotly_service.generate_rainbow,
        "generate_barplot": plotly_service.generate_barplot,
        "generate_distribution": plotly_service.generate_distribution,
        "generate_warming_stripes": plotly_service.generate_warming_stripes,
        "generate_boxplot": plotly_service.generate_boxplot,
        "generate_map": plotly_service.generate_map,
        "matching_case": plotly_service.matching_case,
        "checks": plotly_service.checks,
        # config.ModuleConfig
        "ask_missings": module_config.ask_missings,
        "asks_missings_strategy": asks_missings,
        "__asks_missings_strategy": asks_missings,
        "asksmissings": asks_missings,
        "prompt_to_config": module_config.prompt_to_config,
        "creation_strategy": creation._ConfigCreation__prompt_to_config,
        "__creation_strategy": creation._ConfigCreation__prompt_to_config,
        "__prompt_to_config": creation._ConfigCreation__prompt_to_config,
        "configcreation": creation,
        "clear_local_memory": creation.clear_local_memory,
        "get_climate_variable": creation.get_climate_variable,
        "get_locations": creation.get_locations,
        "get_time_expressions": creation.get_time_expressions,
        "get_graph_type": creation.get_graph_type,
        "get_aggregation_frequency": creation.get_aggregation_frequency,
        "get_aggregation_operator": creation.get_aggregation_operator,
        "get_anomaly": creation.get_anomaly,
        "nltk_ner": creation._ConfigCreation__nltk_ner,
        "__nltk_ner": creation._ConfigCreation__nltk_ner,
        "nltk_ner_helper": creation._ConfigCreation__nltk_ner_helper,
        "__nltk_ner_helper": creation._ConfigCreation__nltk_ner_helper,
        "spacy_ner": creation._ConfigCreation__spacy_ner,
        "__spacy_ner": creation._ConfigCreation__spacy_ner,
        "complete_missings": module_config.complete_missings,
        "get_missings_strategy": completion._CompletionWithLast__complete_missings,
        "__get_missings_strategy": completion._CompletionWithLast__complete_missings,
        "__complete_missings": completion._CompletionWithLast__complete_missings,
        "completionwithlast": completion,
        "process_unknown_location": completion._CompletionWithLast__process_unknown_location,
        "__process_unknown_location": completion._CompletionWithLast__process_unknown_location,
        "process_unknown_climate_variable": completion._CompletionWithLast__process_unknown_climate_variable,
        "__process_unknown_climate_variable": completion._CompletionWithLast__process_unknown_climate_variable,
        "process_unknown_times": completion._CompletionWithLast__process_unknown_times,
        "__process_unknown_times": completion._CompletionWithLast__process_unknown_times,
        "process_unknown_graph_type": completion._CompletionWithLast__process_unknown_graph_type,
        "__process_unknown_graph_type": completion._CompletionWithLast__process_unknown_graph_type,
        "process_unknown_aggregation_type": completion._CompletionWithLast__process_unknown_aggregation_type,
        "__process_unknown_aggregation_type": completion._CompletionWithLast__process_unknown_aggregation_type,
        "process_unknown_aggregation_operator": completion._CompletionWithLast__process_unknown_aggregation_operator,
        "__process_unknown_aggregation_operator": completion._CompletionWithLast__process_unknown_aggregation_operator,
        # chat.config
        "prompt_classification": module_chat.prompt_classification,
        "classification_strategy": classification,
        "__classification_strategy": classification,
        "promptclassification": classification,
        "is_prompt_in_scope": module_chat.is_prompt_in_scope,
        "chatbot_scope_strategy": scopes._ChatbotScopes__prompt_and_scope,
        "__chatbot_scope_strategy": scopes._ChatbotScopes__prompt_and_scope,
        "prompt_and_scope": scopes._ChatbotScopes__prompt_and_scope,
        "__prompt_and_scope": scopes._ChatbotScopes__prompt_and_scope,
        "chatbotscopes": scopes,
        "__add_context": scopes._ChatbotScopes__add_context,
        "is_memory_needed": module_chat.is_memory_needed,
        "memory_strategy": memory._IsMemoryNeeded__is_memory_needed,
        "__memory_strategy": memory._IsMemoryNeeded__is_memory_needed,
        "ismemoryneeded": memory,
        "prompt_rephrasing": module_chat.prompt_rephrasing,
        "rephrasing_strategy": rephrasing._PromptRephrasing__prompt_rephrasing,
        "__rephrasing_strategy": rephrasing._PromptRephrasing__prompt_rephrasing,
        "__prompt_rephrasing": rephrasing._PromptRephrasing__prompt_rephrasing,
        "promptrephrasing": rephrasing,
        "stateless_rephrasing": rephrasing._PromptRephrasing__stateless_rephrasing,
        "__stateless_rephrasing": rephrasing._PromptRephrasing__stateless_rephrasing,
        "check_rephrasing_with_bleu": rephrasing._PromptRephrasing__check_rephrasing_with_bleu,
        "__check_rephrasing_with_bleu": rephrasing._PromptRephrasing__check_rephrasing_with_bleu,
        "get_where_back": rephrasing._PromptRephrasing__get_where_back,
        "__get_where_back": rephrasing._PromptRephrasing__get_where_back,
        "generate_text_answer": module_chat.generate_text_answer,
        "gen_chatbot_answer_strategy": answer._GenChatbotAnswer__generate,
        "__gen_chatbot_answer_strategy": answer._GenChatbotAnswer__generate,
        "genchatbotanswer": answer,
        "not_asking_more_info": answer._GenChatbotAnswer__not_asking_more_info,
        "__not_asking_more_info": answer._GenChatbotAnswer__not_asking_more_info,
        "add_apologies": answer._GenChatbotAnswer__add_apologies,
        "__add_apologies": answer._GenChatbotAnswer__add_apologies,
        # llm
        "sacrebleueval": evals.SacreBleuEval,
        "commandeval": evals.CommandEval,
        "scopeeval": evals.ScopeEval,
        "similarityeval": evals.SimilarityEval,
        "toxicityeval": evals.ToxicityEval,
    }

    if function_name == 'GET_ONLY_FUNCTION_NAMES':
        return list(lookup.keys())

    target = lookup[function_name]
    # The random draw is made even for façade names, whose intro is replaced below.
    intro = random.choice(intro_templates).replace("$$$$", function_name)
    code = inspect.getsource(target)
    if "_strategy" in function_name:
        intro = f" {function_name} is a façade name. Ultimately, it refers to the following:"

    if code.split()[0] == "class":
        closing = "This class is provided as information. It needs to be used within the proper context to work correctly."
    else:
        closing = "This function is provided as information. It may need to be used within its own class to function properly."
    return f"[IS_GEN_FUN_CODE]{intro}\n```python\n{code}\n```\n" + closing
[docs]
def get_list_diva_functions():
    """
    Lists the names for which `get_code_of_function` can return source code.

    Returns:
    --------
    list: The names obtained by calling `get_code_of_function` with the
    'GET_ONLY_FUNCTION_NAMES' sentinel.
    """
    function_names = get_code_of_function('GET_ONLY_FUNCTION_NAMES')
    return function_names
[docs]
def show_energy_consumption(nrg):
    """
    Displays the energy, CO2 and water consumption of the session.

    The figures are written into `st.session_state.consumption_container`, each rounded to two
    decimals, followed by an expander comparing them with real-world orders of magnitude.
    Nothing is displayed when `nrg` is empty or None.

    Parameters:
    -----------
    nrg (dict): Consumption figures read under the keys 'query_energy', 'total_energy',
        'query_CO2', 'total_CO2', 'query_water' and 'total_water'.
    """
    if not nrg:
        return

    # (label, key suffix in nrg, unit) for each displayed metric.
    metrics = (
        ("Energy ⚡", "energy", "kJ"),
        ("CO2 🫧", "CO2", "g"),
        ("Water 💧", "water", "mL"),
    )
    with st.session_state.consumption_container:
        for label, key, unit in metrics:
            last_request = round(nrg[f"query_{key}"], 2)
            session_total = round(nrg[f"total_{key}"], 2)
            st.write(f"• {label} \n↪ Last request: {last_request}{unit} \n↪Total session: :red[{session_total}]{unit}")
        with st.expander("**Compare with real-world data**"):
            # Units are Wh, not kWh: 1 Wh = 3.6 kJ, so 0.3 Wh = 1.08 kJ and
            # 2.9 Wh = 10.4 kJ, which are the kJ values quoted here.
            st.write("• A search on Google uses 0.3 Wh in average, thus 1.08 kJ")
            st.write("• A query to ChatGPT uses 2.9 Wh in average, thus 10.4 kJ")
            st.write("↪ See more information in the documentation section [Environmental impact](http://diva.destine.eurobios.com:8000/environmental_impact.html)")
[docs]
def clean_prompt(prompt):
    """
    Removes trailing spaces and newlines from a prompt.

    Only the space and newline characters are stripped; other trailing whitespace such as
    tabs is kept. An empty prompt, or one made only of spaces and newlines, yields ''.

    Parameters:
    -----------
    prompt (str): The prompt to clean.

    Returns:
    --------
    str: The prompt without trailing spaces and newlines.
    """
    # rstrip stops safely on an empty string, where indexing prompt[-1] would raise.
    return prompt.rstrip(' \n')
[docs]
def translate_chat_msgs(role, msg, language):
    """
    Translates a chat message from English for display, when needed.

    Messages whose role is 'user' are never translated, and nothing is translated when the
    language is 'en'.

    Parameters:
    -----------
    role (str): The role of the message author.
    msg (str): The message text.
    language (str): The language code to translate the message into.

    Returns:
    --------
    str: The original or the translated message.
    """
    needs_translation = role != 'user' and language != 'en'
    if needs_translation:
        msg = translate_from_en(msg, language)
    return msg
[docs]
def get_month_names():
    """
    Returns the three-letter English month abbreviations, January first.

    Returns:
    --------
    list: Twelve strings from "Jan" to "Dec". A new list is built on every call.
    """
    return "Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split()
[docs]
def dic_month_id_to_name():
    """
    Returns the mapping from month number to three-letter English abbreviation.

    Returns:
    --------
    dict: Keys are the integers 1 to 12, values the strings "Jan" to "Dec".
    A new dictionary is built on every call.
    """
    abbreviations = (
        "Jan", "Feb", "Mar", "Apr", "May", "Jun",
        "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    )
    return {month_id: name for month_id, name in enumerate(abbreviations, start=1)}
[docs]
def image_to_base64(image_path):
    """
    Reads a file and returns its content encoded in base64.

    Parameters:
    -----------
    image_path: Path of the file to read; it is opened in binary mode.

    Returns:
    --------
    str: The base64 encoding of the file content, as text.
    """
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()