# Source code for diva.gui.service_streamlit.tab1

# Copyright 2024 Mews
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.

"""Tab 1 that contains the chatbot page
"""
import streamlit as st
import time
import psutil
from pyJoules.energy_meter import EnergyMeter

from loguru import logger

import diva.config as config
import diva.tools as src_tools
import diva.tools as tools
from diva import parameters
from diva import energy
from diva.chat import ModuleChat
from diva.graphs import service_graph_generation
from diva.logging.logger import Process


def get_module_llm():
    """Import and return the language-model generator.

    ``diva.llm.llms`` is imported at call time rather than when this module is
    imported (presumably to defer loading the model until it is needed).

    Returns
    -------
    object
        The ``generator`` attribute of ``diva.llm.llms``.
    """
    from diva.llm import llms

    return llms.generator
def _next_msg_key():
    """Return the key the next chat message will be stored under ("msg_<n>")."""
    return f"msg_{len(st.session_state.messages)}"


def _append_message(role, content):
    """Store a chat message in the session history under the next free key."""
    st.session_state.messages[_next_msg_key()] = {"role": role, "content": content}


def _greeting_messages(source_lang):
    """Return a fresh history holding only the (translated) greeting."""
    first_sentence = tools.translate_from_en(parameters.first_sentence, source_lang)
    return {"msg_0": {"role": "assistant", "content": first_sentence}}


def _init_conversation_state(source_lang):
    """Create the conversation entries of ``st.session_state`` that are missing."""
    state = st.session_state
    if "messages" not in state:
        state.messages = _greeting_messages(source_lang)
    defaults = {
        "plots_history": {},
        "config_params": {},
        "last_response_is_fig": False,
        "conv_cleared": True,
        "data_history": {},
    }
    for key, value in defaults.items():
        if key not in state:
            state[key] = value


def _clear_conversation(source_lang):
    """Reset history, chat/config modules and logging, then restart the run.

    Never returns normally: ``st.rerun()`` halts the current script run.
    """
    logger.info("CLEAR CONVERSATION")
    state = st.session_state
    # Streamlit interface
    state.messages = _greeting_messages(source_lang)
    state.plots_history = {}
    state.config_params = {}
    state.data_history = {}
    state.last_response_is_fig = False
    state.conv_cleared = True
    # Chat & config
    state.module_chat = ModuleChat()
    state.module_config = config.ModuleConfig(state.module_chat)
    # Logging view
    state.process_logging = Process(text_language="html")
    state.history_logging = ''
    st.rerun()


def _render_history(source_lang):
    """Redraw every stored message, followed by its figure when it has one.

    Figures also get a download button, except the newest one when
    ``_render_last_figure_buttons`` is going to draw its buttons instead.
    """
    state = st.session_state
    plot_ids = list(state.plots_history)
    last_plot_id = plot_ids[-1] if plot_ids else None
    # Mirrors the condition of _render_last_figure_buttons(): only skip the
    # newest figure here when that function draws its buttons, so the figure
    # keeps a download button once another (non-figure) answer follows it.
    last_fig_has_buttons = bool(state.config_params) and state.last_response_is_fig
    for msg_id, message in state.messages.items():
        role = message["role"]
        st.chat_message(role).write(
            tools.translate_chat_msgs(role, message["content"], source_lang)
        )
        if msg_id not in state.plots_history:
            continue
        fig = state.plots_history[msg_id]
        if isinstance(fig, (parameters.altair_type_fig, parameters.altair_type_fig_concat)):
            st.altair_chart(fig, use_container_width=True)
        else:
            st.plotly_chart(fig, use_container_width=True)
        if msg_id == last_plot_id and last_fig_has_buttons:
            continue
        if msg_id not in state.data_history:
            # No exported data was recorded for this figure: nothing to download.
            continue
        title, data_csv = state.data_history[msg_id]
        tools.add_download_button_after_prompts(
            title, data_csv, message, state.config_params[msg_id], key=msg_id
        )


def _render_last_figure_buttons():
    """Draw download & feedback buttons for the newest figure if the latest answer was a figure."""
    state = st.session_state
    if not (state.config_params and state.last_response_is_fig):
        return
    plot_ids = list(state.plots_history)
    if not plot_ids:
        return
    last_plot_id = plot_ids[-1]
    last_msg_id = list(state.messages)[-1]
    tools.add_download_feedback_button(
        state.messages[last_msg_id],
        state.config_params[last_plot_id],
        state.plots_history[last_plot_id],
    )


def _read_prompt(tab1_options, source_lang):
    """Return the chat-box text, or the prompt of a clicked suggestion card.

    Suggestion cards (at most three, one per column) are only shown while the
    conversation is freshly cleared. Card labels and their prompts are stored in
    English and translated for display.
    """
    prompt = st.chat_input("Ask question...")
    if not st.session_state.conv_cleared:
        return prompt
    if tab1_options['dev_mode']:
        proposals = parameters.sentences_proposed_dev
    else:
        proposals = parameters.sentences_proposed_user
    # zip() stops at the shorter side, so fewer than three cards is tolerated.
    for column, card_label in zip(st.columns(3), proposals):
        with column:
            if st.button(tools.translate_from_en(card_label, source_lang)):
                prompt = tools.translate_from_en(proposals[card_label], source_lang)
                st.session_state.conv_cleared = False
    return prompt


def _reject_empty_prompt(prompt):
    """Return ``prompt`` unchanged, or answer "Please write something ..." and return False if it is empty."""
    if not tools.chech_if_empty_prompt(prompt):
        return prompt
    _append_message("user", "")
    with st.chat_message("user"):
        st.markdown("")
    answer = "Please write something ..."
    with st.chat_message("assistant"):
        st.write_stream(tools.write_like_chatGPT(answer))
    _append_message("assistant", answer)
    return False


def _answer_discussion(module_chat, prompt, source_lang):
    """Show, store, log and offer feedback on the LLM's plain-text answer."""
    llm_answer = tools.translate_from_en(module_chat.chat.llm_answer, source_lang)
    with st.chat_message("assistant"):
        st.write_stream(tools.write_like_chatGPT(llm_answer))
    _append_message("assistant", llm_answer)
    tools.update_logs(prompt, str(module_chat.prompt), llm_answer)
    tools.add_feedback_msg_button(prompt, llm_answer)
    st.session_state.last_response_is_fig = False


def _answer_visualisation(module_chat, module_config, energy_meter, energy_records,
                          prompt, tab1_options, source_lang):
    """Extract a graph configuration from the prompt; draw the graph or ask for missing info."""
    with st.spinner("Extracting information...⌛"):
        # ------------- GPU energy recording: start -------------
        energy_meter.start()
        module_config.prompt_to_config(module_chat.prompt)
        energy_meter.stop()
        energy_records.set_gpu(energy_meter)
        # ------------- GPU energy recording: stop --------------
    llm_answer = module_chat.chat.llm_answer
    config_params = tools.convert_to_dict_config(module_config.config)
    logger.info(f"missing info --> {module_config.missings}")
    if module_config.missings:
        _ask_for_missing_info(module_chat, module_config, config_params, prompt, source_lang)
    else:
        _display_graph(module_chat, module_config, config_params, llm_answer,
                       prompt, tab1_options, source_lang)


def _display_graph(module_chat, module_config, config_params, llm_answer,
                   prompt, tab1_options, source_lang):
    """Describe the extracted configuration, then generate the figure and its buttons."""
    unknown_locations = module_config.config.not_in_shp
    if len(unknown_locations) > 0:
        if len(module_config.config.location.split(", ")) == 1:
            extra = "location"
        else:
            extra = "locations"
        llm_answer = (
            f"My apologies, I don't know {src_tools.enumeration(unknown_locations)}"
            f" but I can answer for the other {extra}."
        )

    # -------- 2.1 display the message describing the config --------
    graph_gen = service_graph_generation.ServiceGeneratePlotlyGraph(
        config_params, tab1_options["langage"], st.session_state.connected_user['user_type']
    )
    config_params['aggreg_type'] = graph_gen.aggreg_type
    display_config = f"{llm_answer} \n" + tools.display_config(config_params) + " \n"
    if config_params['graph_type'] == 'warming stripes':
        display_config += "The calculations of warming stripes are based on the period 1971-2000."
    with st.chat_message("assistant"):
        display_config = tools.translate_from_en(display_config, source_lang)
        st.write_stream(tools.write_like_chatGPT(display_config))
    # The config is stored under the same key as the assistant message it belongs to.
    msg_key = _next_msg_key()
    st.session_state.config_params[msg_key] = config_params
    st.session_state.messages[msg_key] = {"role": "assistant", "content": display_config}
    st.session_state.last_response_is_fig = True

    # -------- 2.2 display the figure --------
    with st.spinner(tools.generate_waiting_for_graph_msg() + "⌛"):
        graph_gen.generate()
    plot_ids = list(st.session_state.plots_history)
    if plot_ids:
        last_plot_id = plot_ids[-1]
        title, data_csv = tools.add_download_feedback_button(
            prompt,
            config_params,
            st.session_state.plots_history[last_plot_id],
            "in_prompt",
        )
        # Kept so the download button can be redrawn on later reruns.
        st.session_state.data_history[last_plot_id] = (title, data_csv)
    tools.update_logs(prompt, str(module_chat.prompt), tools.from_dict_to_str(config_params))


def _ask_for_missing_info(module_chat, module_config, config_params, prompt, source_lang):
    """Ask the user for the configuration items the prompt did not provide."""
    missing_questions = [
        tools.get_link_question_for_missing(link) for link in module_config.missings
    ]
    unknown_locations = module_config.config.not_in_shp
    if len(unknown_locations) > 0:
        if len(unknown_locations) == 1:
            extra = "another location"
        else:
            extra = "other locations"
        request = (
            f"My apologies, I don't know {src_tools.enumeration(unknown_locations)}."
            f" Could you please ask me again for {extra}?"
        )
        # Drop the question at index 2 of link_questions (presumably the location
        # question, already covered by the sentence above — confirm). Guarded:
        # list.remove raises ValueError when the item is absent.
        location_question = list(
            tools.get_link_question_for_missing.link_questions.values()
        )[2]
        if location_question in missing_questions:
            missing_questions.remove(location_question)
        if len(missing_questions) > 0:
            request += f" In addition, could you please precise me {src_tools.enumeration(missing_questions)}."
    else:
        request = "Thank you for your request. "
        if config_params['climate_variable'].lower() in parameters.available_vars:
            request += f"I can give you a graphical view of the {config_params['climate_variable']}. "
        request += f"Could you please precise me {src_tools.enumeration(missing_questions)} ?"
    # ask_missings() can build a request for the user (unused here) and another one
    # used only by the LLM to decide whether the next prompt supplies the missing
    # parameters, in which case context from this prompt is needed.
    module_config.ask_missings()
    with st.chat_message("assistant"):
        request = tools.translate_from_en(request, source_lang)
        st.write_stream(tools.write_like_chatGPT(request))
    _append_message("assistant", request)
    st.session_state.last_response_is_fig = False
    tools.update_logs(prompt, str(module_chat.prompt), request)
    tools.add_feedback_msg_button(prompt, request)


def _record_query_energy(p, energy_records):
    """Add this query's CPU/RAM energy and publish ``st.session_state.energy_consumption``."""
    cpu_percent = p.cpu_percent()
    # Capped at 11 s (reason not visible here — presumably to bound the estimate).
    duration = min(time.time() - st.session_state.time_last_call_p, 11)
    st.session_state.time_last_call_p = time.time()
    # Subtract the value at 0 % CPU so only the load-dependent share is recorded.
    energy_records.set_cpu(
        energy.get_energy_cpu(cpu_percent=cpu_percent, duration=duration)
        - energy.get_energy_cpu(cpu_percent=0, duration=duration)
    )
    energy_records.set_ram(
        energy.get_energy_ram(cpu_percent=cpu_percent, duration=duration)
        - energy.get_energy_ram(cpu_percent=0, duration=duration)
    )
    st.session_state.energy_consumption = {
        "query_energy": round(energy_records.query / 1000, 2),
        "query_CO2": energy_records.get_co2(type_="query"),
        "query_water": energy_records.get_water(type_="query"),
        "total_energy": round(energy_records.total / 1000, 2),
        "total_CO2": energy_records.get_co2(type_="total"),
        "total_water": energy_records.get_water(type_="total"),
    }


def main(tab1_options):
    """Render Tab 1 (chatbot page) and process the user's latest request.

    On every Streamlit run this initialises the chat/config modules and the
    energy-measurement objects in ``st.session_state``, optionally clears the
    conversation, redraws the message/figure history and reads the user's
    input. When a prompt is submitted, it is classified as a discussion and/or
    a visualisation request and answered accordingly. The query's energy use is
    recorded, the logs are updated, and the script is rerun.

    Parameters
    ----------
    tab1_options : dict
        Keys read here: ``"langage"`` (display language name, converted with
        ``tools.langage_to_iso``), ``"clear_conv"`` (truthy to reset the
        conversation) and ``"dev_mode"`` (selects the developer suggestion cards).

    Returns
    -------
    None

    Notes
    -----
    ``st.session_state.connected_user``, ``st.session_state.log_container`` and
    ``st.session_state.history_logging`` are read but not created here
    (``history_logging`` is only reset when clearing). They are presumably
    initialised by the calling app.
    """
    state = st.session_state
    if "module_chat" not in state:
        state.module_chat = ModuleChat()
    module_chat = state.module_chat
    if "module_config" not in state:
        state.module_config = config.ModuleConfig(module_chat)
    module_config = state.module_config

    # p is a psutil.Process handle used for the CPU share of the energy estimate.
    if "p" not in state:
        state.p = psutil.Process()
    p = state.p
    # cpu_percent() without an interval measures since the previous call; the
    # return value is ignored, the call only restarts the measurement window.
    p.cpu_percent()
    if "time_last_call_p" not in state:
        state.time_last_call_p = time.time()
    if "energy_meter" not in state:
        state.energy_meter = EnergyMeter(energy.devices)
    energy_meter = state.energy_meter
    if "energy_records" not in state:
        state.energy_records = energy.EnergyRecords()
    energy_records = state.energy_records

    source_lang = tools.langage_to_iso(tab1_options["langage"])
    state.process_logging = Process(text_language='html')

    # --------------- step 1: conversation history ---------------
    _init_conversation_state(source_lang)
    if tab1_options["clear_conv"]:
        _clear_conversation(source_lang)  # ends this run via st.rerun()
    _render_history(source_lang)
    _render_last_figure_buttons()

    # --------------- step 2: user input ---------------
    prompt = _read_prompt(tab1_options, source_lang)
    if prompt:
        prompt = _reject_empty_prompt(prompt)
    if not prompt:
        return

    tools.send_user_event("prompt")
    prompt = tools.clean_prompt(prompt)

    # Start a fresh energy measurement window for this query.
    energy_records.clear_query()
    p.cpu_percent()
    state.time_last_call_p = time.time()

    # Show and store the user's message as typed (in the user's language).
    _append_message("user", prompt)
    state.conv_cleared = False
    with st.chat_message("user"):
        st.markdown(prompt)

    # Everything below works on the English translation of the prompt.
    prompt = tools.translate_to_en(prompt, source_lang)

    with st.spinner("Processing the request...⌛"):
        module_chat.create_user_prompt(prompt)
        assert module_chat.prompt is not None
        # ------------- GPU energy recording: start -------------
        energy_meter.start()
        module_chat.prompt_classification()
        # Rephrase into correct English, adding memory context when needed.
        module_chat.prompt_rephrasing()
        module_chat.is_prompt_in_scope()
        module_chat.generate_text_answer()
        energy_meter.stop()
        energy_records.set_gpu(energy_meter)
        # ------------- GPU energy recording: stop --------------
    type_of_request = module_chat.prompt.type
    logger.debug(repr(module_chat.prompt))

    # 1 - plain discussion request
    if "discussion" in type_of_request:
        _answer_discussion(module_chat, prompt, source_lang)
    # 2 - graph visualisation request
    if "visualisation" in type_of_request or "visualization" in type_of_request:
        _answer_visualisation(module_chat, module_config, energy_meter, energy_records,
                              prompt, tab1_options, source_lang)

    # Final energy computation for the current query, then logs and rerun.
    _record_query_energy(p, energy_records)
    state.history_logging = tools.update_logging_history(
        state.history_logging, state.process_logging.doc(), prompt
    )
    state.log_container.markdown(state.history_logging)
    tools.show_energy_consumption(state.energy_consumption)
    st.rerun()