# Copyright 2024 Mews
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
"""Tab 1 that contains the chatbot page
"""
import streamlit as st
import time
import psutil
from pyJoules.energy_meter import EnergyMeter
from loguru import logger
import diva.config as config
import diva.tools as tools
from diva import parameters
from diva import energy
from diva.chat import ModuleChat
from diva.graphs import service_graph_generation
from diva.logging.logger import Process
def get_module_llm():
"""
Imports and returns the default language model module.
This function dynamically imports the `module_llm` from the `llm` package and returns
the default language model (`default_llm`) defined within that module.
Returns:
-------
object
The default language model (`default_llm`) from the `module_llm`.
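    Examples:
    --------
    A minimal usage sketch (the generator's call signature is defined in
    `diva.llm.llms` and is not assumed here):
    >>> llm = get_module_llm()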
"""
from diva.llm import llms
return llms.generator
def main(tab1_options):
"""
Defines and manages the main page of the application within Tab 1, including chatbot interactions and graph visualizations.
This function initializes the necessary components (language model, chat module, configuration),
manages the chat history, and handles user input for interacting with the chatbot. It supports
both simple discussions and requests for visualizations, displaying the results accordingly.
Parameters:
----------
    tab1_options : dict
        A dictionary of options for the current tab. The keys read by this
        function are ``langage`` (the display language), ``clear_conv`` (a
        flag that resets the conversation history) and ``dev_mode`` (which
        selects the set of proposed example prompts).
Returns:
-------
None
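    Examples:
    --------
    A minimal ``tab1_options`` payload, with illustrative values for the
    keys this function reads:
    >>> main({"langage": "English", "clear_conv": False, "dev_mode": False})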
"""
if "module_chat" not in st.session_state:
st.session_state.module_chat = ModuleChat()
module_chat = st.session_state.module_chat
if "module_config" not in st.session_state:
st.session_state.module_config = config.ModuleConfig(module_chat)
module_config = st.session_state.module_config
    # p tracks this Streamlit process' CPU usage via psutil
    if "p" not in st.session_state:
        st.session_state.p = psutil.Process()
    p = st.session_state.p
    p.cpu_percent()  # priming call: the next call reports usage since this one
if "time_last_call_p" not in st.session_state:
st.session_state.time_last_call_p = time.time()
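    # psutil's cpu_percent() reports utilisation since the *previous* call,
    # which is why it is primed above and why the time of the last call is
    # kept in session state. A minimal standalone sketch of the pattern
    # (hypothetical, outside Streamlit):
    #
    #     proc = psutil.Process()
    #     proc.cpu_percent()          # prime: the first value is meaningless
    #     time.sleep(1.0)
    #     load = proc.cpu_percent()   # average CPU % over the last second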
if "energy_meter" not in st.session_state:
st.session_state.energy_meter = EnergyMeter(energy.devices)
energy_meter = st.session_state.energy_meter
if "energy_records" not in st.session_state:
st.session_state.energy_records = energy.EnergyRecords()
energy_records = st.session_state.energy_records
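    # pyJoules' EnergyMeter wraps the devices listed in energy.devices behind
    # a start()/stop() pair and records consumption between the two calls. A
    # minimal sketch, assuming energy.devices holds pyJoules device instances
    # and run_workload is a hypothetical placeholder:
    #
    #     meter = EnergyMeter(energy.devices)
    #     meter.start()
    #     run_workload()
    #     meter.stop()
    #     trace = meter.get_trace()   # per-device energy samples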
source_lang = tools.langage_to_iso(tab1_options["langage"])
st.session_state.process_logging = Process(text_language='html')
# --------------- step 1 ---------------
# Display chat messages from history on app rerun
if "messages" not in st.session_state:
first_sentence = parameters.first_sentence
first_sentence = tools.translate_from_en(first_sentence, source_lang)
st.session_state.messages = {
"msg_0": {"role": "assistant", "content": first_sentence}
}
if "plots_history" not in st.session_state:
st.session_state.plots_history = {}
if "config_params" not in st.session_state:
st.session_state.config_params = {}
if "last_response_is_fig" not in st.session_state:
st.session_state.last_response_is_fig = False
if "conv_cleared" not in st.session_state:
st.session_state.conv_cleared = True
if "data_history" not in st.session_state:
st.session_state.data_history = {}
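    # st.session_state survives Streamlit's top-to-bottom reruns, so each
    # "if key not in st.session_state" guard above runs its initialiser
    # exactly once per browser session, e.g. (hypothetical key):
    #
    #     if "counter" not in st.session_state:
    #         st.session_state.counter = 0
    #     st.session_state.counter += 1   # persists across reruns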
if tab1_options["clear_conv"]:
print("CLEAR CONVERSATION")
# Interface streamlit
first_sentence = parameters.first_sentence
first_sentence = tools.translate_from_en(first_sentence, source_lang)
st.session_state.messages = {
"msg_0": {"role": "assistant", "content": first_sentence}
}
st.session_state.plots_history = {}
st.session_state.config_params = {}
st.session_state.data_history = {}
st.session_state.last_response_is_fig = False
st.session_state.conv_cleared = True
# Chat & config
st.session_state.module_chat = ModuleChat()
module_chat = st.session_state.module_chat
st.session_state.module_config = config.ModuleConfig(module_chat)
module_config = st.session_state.module_config
# Logging view
st.session_state.process_logging = Process(text_language="html")
        st.session_state.history_logging = ''
        st.rerun()
    # Rebuild the chat messages and figures from history
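    # st.session_state.messages maps "msg_<n>" ids to {"role", "content"}
    # dicts; plots_history, data_history and config_params are keyed by the
    # same ids, which is how each figure is re-attached to its message below.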
for msg_id in st.session_state.messages.keys():
message = st.session_state.messages[msg_id]
st.chat_message(message["role"]).write(tools.translate_chat_msgs(message["role"], message["content"], source_lang))
if msg_id in st.session_state.plots_history.keys():
fig = st.session_state.plots_history[msg_id]
            if isinstance(fig, (parameters.altair_type_fig, parameters.altair_type_fig_concat)):
st.altair_chart(fig, use_container_width=True)
else:
st.plotly_chart(fig, use_container_width=True)
if msg_id != list(st.session_state.plots_history.keys())[-1]:
title, data_csv = st.session_state.data_history[msg_id]
tools.add_download_button_after_prompts(title, data_csv,
message, st.session_state.config_params[msg_id],
key=msg_id)
# build buttons (download & feedback) for last fig
if (st.session_state.config_params != {}
and st.session_state.last_response_is_fig):
msg_plots = list(st.session_state.plots_history.keys())
msg_prompt = list(st.session_state.messages.keys())
        if msg_plots:
msg_id = msg_plots[-1]
msg_id_ = msg_prompt[-1]
tools.add_download_feedback_button(
st.session_state.messages[msg_id_],
st.session_state.config_params[msg_id],
st.session_state.plots_history[msg_id],
)
# --------------- step 2 ---------------
# get user input
    prompt = st.chat_input("Ask a question...")
if st.session_state.conv_cleared:
if tab1_options['dev_mode']:
sentences_proposed = parameters.sentences_proposed_dev
else:
sentences_proposed = parameters.sentences_proposed_user
sentences_on_cards = list(sentences_proposed.keys())
        cols = st.columns(3)
        for col, sentence in zip(cols, sentences_on_cards[:3]):
            with col:
                label = tools.translate_from_en(sentence, source_lang)
                if st.button(label):
                    prompt = tools.translate_from_en(
                        sentences_proposed[sentence], source_lang)
                    st.session_state.conv_cleared = False
if prompt:
prompt_is_empty = tools.chech_if_empty_prompt(prompt)
if prompt_is_empty:
prompt = False
st.session_state.messages[
"msg_{}".format(len(st.session_state.messages))
] = {"role": "user", "content": ""}
with st.chat_message("user"):
st.markdown("")
with st.chat_message("assistant"):
answer = "Please write something ..."
st.write_stream(tools.write_like_chatGPT(answer))
st.session_state.messages[
"msg_{}".format(len(st.session_state.messages))
] = {"role": "assistant", "content": answer}
    # TODO: split the code below into several functions:
    # one for discussion, one for visualisation, one for chat initialisation, etc.
if prompt:
tools.send_user_event("prompt")
prompt = tools.clean_prompt(prompt)
        # Reset the energy-recording window for this query
        energy_records.clear_query()
        p.cpu_percent()  # re-prime the CPU counter at the start of the query
        st.session_state.time_last_call_p = time.time()
st.session_state.messages["msg_{}".format(len(st.session_state.messages))] = {
"role": "user",
"content": prompt,
}
st.session_state.conv_cleared = False
        # Display the user message
with st.chat_message("user"):
st.markdown(prompt)
prompt = tools.translate_to_en(prompt, source_lang)
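        # The pipeline is English-internal: the prompt is translated to
        # English here, and every assistant answer is translated back with
        # translate_from_en before being displayed.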
with st.spinner("Processing the request...⌛"):
module_chat.create_user_prompt(prompt)
assert module_chat.prompt is not None
# ------------- start energy recording gpu ----------------
energy_meter.start()
            module_chat.prompt_classification()
            # Rephrase into correct English, adding memory context when the
            # prompt depends on earlier turns
            module_chat.prompt_rephrasing()
module_chat.is_prompt_in_scope()
module_chat.generate_text_answer()
energy_meter.stop()
energy_records.set_gpu(energy_meter)
# ------------- stop energy recording gpu ----------------
type_of_request = module_chat.prompt.type
        logger.debug(repr(module_chat.prompt))
        # 1 - If the request is a simple discussion
if "discussion" in type_of_request:
llm_answer = module_chat.chat.llm_answer
llm_answer = tools.translate_from_en(llm_answer, source_lang)
with st.chat_message("assistant"):
st.write_stream(tools.write_like_chatGPT(llm_answer))
st.session_state.messages[
"msg_{}".format(len(st.session_state.messages))
] = {"role": "assistant", "content": llm_answer}
tools.update_logs(prompt, str(module_chat.prompt), llm_answer)
tools.add_feedback_msg_button(prompt, llm_answer)
st.session_state.last_response_is_fig = False
        # 2 - If the request is a graph visualisation
if "visualisation" in type_of_request or "visualization" in type_of_request:
with st.spinner("Extracting information...⌛"):
# ------------- start energy recording gpu ----------------
energy_meter.start()
module_config.prompt_to_config(module_chat.prompt)
energy_meter.stop()
energy_records.set_gpu(energy_meter)
# ------------- stop energy recording gpu ----------------
llm_answer = module_chat.chat.llm_answer
config_params = tools.convert_to_dict_config(module_config.config)
logger.info(f"missing info --> {module_config.missings}")
if not module_config.missings:
if len(module_config.config.not_in_shp) > 0:
if len(module_config.config.location.split(", ")) == 1:
extra = "location"
else:
extra = "locations"
                    llm_answer = (
                        f"My apologies, I don't know {tools.enumeration(module_config.config.not_in_shp)}"
                        f" but I can answer for the other {extra}."
                    )
# -------- 2.1 display msg for config --------
graph_gen = service_graph_generation.ServiceGeneratePlotlyGraph(
config_params, tab1_options["langage"], st.session_state.connected_user['user_type']
)
config_params['aggreg_type'] = graph_gen.aggreg_type
display_config = f"{llm_answer} \n" + tools.display_config(config_params) + " \n"
                if config_params['graph_type'] == 'warming stripes':
                    display_config += "The warming stripes calculation is based on the period 1971-2000."
with st.chat_message("assistant"):
display_config = tools.translate_from_en(display_config, source_lang)
st.write_stream(tools.write_like_chatGPT(display_config))
st.session_state.config_params["msg_{}".format(len(st.session_state.messages))] = config_params
st.session_state.messages[
"msg_{}".format(len(st.session_state.messages))
] = {"role": "assistant", "content": display_config}
st.session_state.last_response_is_fig = True
# -------- 2.2 display the figure --------
with st.spinner(tools.generate_waiting_for_graph_msg() + "⌛"):
graph_gen.generate()
msg_plots = list(st.session_state.plots_history.keys())
if msg_plots:
msg_id = msg_plots[-1]
title, data_csv = tools.add_download_feedback_button(
prompt,
config_params,
st.session_state.plots_history[msg_id],
"in_prompt",
)
st.session_state.data_history[msg_id] = (title, data_csv)
tools.update_logs(
prompt, str(module_chat.prompt), tools.from_dict_to_str(config_params)
)
            else:
                link_sentences_for_missings = [
                    tools.get_link_question_for_missing(link) for link in module_config.missings
                ]
if len(module_config.config.not_in_shp) > 0:
if len(module_config.config.not_in_shp) == 1:
extra = "another location"
else:
extra = "other locations"
                    ask_for_complete_request = (
                        f"My apologies, I don't know {tools.enumeration(module_config.config.not_in_shp)}."
                        f" Could you please ask me again for {extra}?"
                    )
                    # Drop the third link question (presumably the location
                    # one), which is already covered by the apology above
                    to_remove = list(
                        tools.get_link_question_for_missing.link_questions.values()
                    )[2]
                    link_sentences_for_missings.remove(to_remove)
                    if link_sentences_for_missings:
                        ask_for_complete_request += f" In addition, could you please specify {tools.enumeration(link_sentences_for_missings)}?"
                else:
                    ask_for_complete_request = "Thank you for your request. "
                    if config_params['climate_variable'].lower() in parameters.available_vars:
                        ask_for_complete_request += f"I can give you a graphical view of the {config_params['climate_variable']}. "
                    ask_for_complete_request += f"Could you please specify {tools.enumeration(link_sentences_for_missings)}?"
                # module_config.ask_missings() generates a question for the
                # user (discarded here) and a second prompt used only by the
                # LLM to decide whether the next user prompt supplies the
                # missing parameters, in which case context from the previous
                # prompt is reused.
_ = module_config.ask_missings()
with st.chat_message("assistant"):
ask_for_complete_request = tools.translate_from_en(ask_for_complete_request, source_lang)
st.write_stream(tools.write_like_chatGPT(ask_for_complete_request))
st.session_state.messages[
"msg_{}".format(len(st.session_state.messages))
] = {"role": "assistant", "content": ask_for_complete_request}
st.session_state.last_response_is_fig = False
tools.update_logs(prompt, str(module_chat.prompt), ask_for_complete_request)
tools.add_feedback_msg_button(prompt, ask_for_complete_request)
# final computation of energy recording for the current query
        cpu_percent = p.cpu_percent()
        # Cap the elapsed interval so long idle gaps do not inflate the estimate
        duration = min(time.time() - st.session_state.time_last_call_p, 11)
        st.session_state.time_last_call_p = time.time()
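        # Net energy attributed to the query: consumption modelled at the
        # measured CPU load minus the idle (0 %) baseline over the same
        # interval, i.e. E_net = E(cpu_percent, t) - E(0, t).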
energy_records.set_cpu(
energy.get_energy_cpu(cpu_percent=cpu_percent, duration=duration) - energy.get_energy_cpu(cpu_percent=0, duration=duration)
)
energy_records.set_ram(
energy.get_energy_ram(cpu_percent=cpu_percent, duration=duration) - energy.get_energy_ram(cpu_percent=0, duration=duration)
)
st.session_state.energy_consumption = {"query_energy": round(energy_records.query / 1000, 2),
"query_CO2": energy_records.get_co2(type_="query"),
"query_water": energy_records.get_water(type_="query"),
"total_energy": round(energy_records.total / 1000, 2),
"total_CO2": energy_records.get_co2(type_="total"),
"total_water": energy_records.get_water(type_="total")}
st.session_state.history_logging = tools.update_logging_history(st.session_state.history_logging,
st.session_state.process_logging.doc(),
prompt)
st.session_state.log_container.markdown(st.session_state.history_logging)
nrg = st.session_state.energy_consumption
tools.show_energy_consumption(nrg)
st.rerun()