# Copyright 2024 Mews
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
"""
Tab 1: the chatbot page (chat history, prompt handling, visualizations and energy tracking).
"""
import streamlit as st
import time
import psutil
from pyJoules.energy_meter import EnergyMeter
from loguru import logger
import diva.config as config
import diva.tools as src_tools
import diva.tools as tools
from diva import parameters
from diva import energy
from diva.chat import ModuleChat
from diva.graphs import service_graph_generation
from diva.logging_.logger import Process
[docs]
def get_module_llm():
    """
    Import and return the project's text-generation model.

    ``diva.llm.llms`` is imported lazily, inside the function, so the model
    module is only loaded when a caller actually asks for it.

    Returns:
    -------
    object
        The ``generator`` attribute of ``diva.llm.llms``.
    """
    from diva.llm import llms

    return llms.generator
[docs]
def initialize_session_state(module_chat, langage):
    """
    Create every session-state entry Tab 1 relies on, keeping existing values.

    Parameters
    ----------
    module_chat : ModuleChat
        Chat module used to build ``module_config`` when it does not exist yet.
    langage : str
        Language selected by the user; converted with ``tools.langage_to_iso``
        to translate the greeting that opens a new conversation.
    """
    state = st.session_state
    if "module_config" not in state:
        state.module_config = config.ModuleConfig(module_chat)
    if "p" not in state:
        state.p = psutil.Process()
        # The first non-blocking cpu_percent() call only starts psutil's
        # measurement window; its (meaningless) result is discarded.
        state.p.cpu_percent()
    if "time_last_call_p" not in state:
        state.time_last_call_p = time.time()
    if "energy_meter" not in state:
        state.energy_meter = EnergyMeter(energy.devices)
    if "energy_records" not in state:
        state.energy_records = energy.EnergyRecords()
    if "messages" not in state:
        source_lang = tools.langage_to_iso(langage)
        greeting = tools.translate_from_en(parameters.first_sentence, source_lang)
        state.messages = {"msg_0": {"role": "assistant", "content": greeting}}
    if "plots_history" not in state:
        state.plots_history = {}
    if "config_params" not in state:
        state.config_params = {}
    if "last_response_is_fig" not in state:
        state.last_response_is_fig = False
    if "conv_cleared" not in state:
        state.conv_cleared = True
    if "data_history" not in state:
        state.data_history = {}
    if "history_logging" not in state:
        # process_prompt() extends this string, but in this module only
        # clear_conversation() assigns it; default it so the first prompt
        # can always read it.
        state.history_logging = ''
    # Recreated on every script run (not guarded like the entries above), so
    # the process log only covers the current run.
    # NOTE(review): confirm this scoping is intended.
    state.process_logging = Process(text_language='html')
[docs]
def clear_conversation(source_lang):
    """
    Reset the conversation to its initial state.

    Replaces the message history with the translated greeting, empties the
    plot, configuration and data histories, and rebuilds the chat module,
    its configuration module and the process logger for the connected user.

    Parameters
    ----------
    source_lang : str
        ISO code of the user's language, used to translate the greeting.
    """
    greeting = tools.translate_from_en(parameters.first_sentence, source_lang)
    st.session_state.messages = {
        "msg_0": {"role": "assistant", "content": greeting}
    }
    st.session_state.plots_history = {}
    st.session_state.config_params = {}
    st.session_state.data_history = {}
    st.session_state.last_response_is_fig = False
    st.session_state.conv_cleared = True
    st.session_state.module_chat = ModuleChat(st.session_state.connected_user['user_type'])
    st.session_state.module_config = config.ModuleConfig(st.session_state.module_chat)
    st.session_state.process_logging = Process(text_language="html")
    # The log history is now empty, so there is nothing to re-render in the
    # log container (a non-empty check right after this assignment could
    # never succeed).
    st.session_state.history_logging = ''
[docs]
def rebuild_chat_history(source_lang):
    """
    Re-render every stored message and figure from session state.

    Each message is translated with ``tools.translate_chat_msgs`` and written in
    its chat bubble. A message that owns a figure also gets the chart (Altair or
    Plotly) and a download button. When the latest response is a figure, that
    figure instead receives the combined download + feedback widget at the end.

    Parameters
    ----------
    source_lang : str
        ISO code of the user's language.
    """
    state = st.session_state
    plot_ids = list(state.plots_history)
    last_plot_id = plot_ids[-1] if plot_ids else None
    # The newest figure gets the combined download/feedback widget only when it
    # is the latest response. Otherwise it is treated like any older figure, so
    # it keeps a plain download button instead of losing it.
    show_final_widgets = (
        last_plot_id is not None
        and state.config_params != {}
        and state.last_response_is_fig
    )

    for msg_id, message in state.messages.items():
        st.chat_message(message["role"]).write(
            tools.translate_chat_msgs(message["role"], message["content"], source_lang)
        )
        if msg_id not in state.plots_history:
            continue
        fig = state.plots_history[msg_id]
        if isinstance(fig, (parameters.altair_type_fig, parameters.altair_type_fig_concat)):
            st.altair_chart(fig, use_container_width=True)
        else:
            st.plotly_chart(fig, use_container_width=True, key=msg_id + '_plotly_chart')
        if show_final_widgets and msg_id == last_plot_id:
            continue  # handled by the final widget below
        entry = state.data_history.get(msg_id)
        if entry is None:
            # No export data was stored for this figure; nothing to download.
            continue
        title, data_csv = entry
        tools.add_download_button_after_prompts(
            title, data_csv, message,
            state.config_params[msg_id],
            key=msg_id + '_download_button',
        )

    if show_final_widgets:
        last_msg_id = list(state.messages)[-1]
        tools.add_download_feedback_button(
            state.messages[last_msg_id],
            state.config_params[last_plot_id],
            state.plots_history[last_plot_id],
        )
[docs]
def handle_discussion(module_chat, prompt, source_lang):
    """
    Answer a conversational (non-visualization) request.

    The LLM answer is translated into the user's language, streamed into an
    assistant bubble and appended to the message history. The exchange is then
    logged and a feedback button is added.

    Parameters
    ----------
    module_chat : ModuleChat
        Chat module whose ``chat.llm_answer`` holds the generated answer.
    prompt : str
        The user prompt (``process_prompt`` passes it after translation to English).
    source_lang : str
        ISO code of the user's language.
    """
    translated = tools.translate_from_en(module_chat.chat.llm_answer, source_lang)
    with st.chat_message("assistant"):
        st.write_stream(tools.write_like_chatGPT(translated))

    history = st.session_state.messages
    history[f"msg_{len(history)}"] = {"role": "assistant", "content": translated}

    tools.update_logs(prompt, str(module_chat.prompt), translated)
    tools.add_feedback_msg_button(prompt, translated)
    st.session_state.last_response_is_fig = False
[docs]
def handle_visualization(module_chat, module_config, prompt, source_lang, energy_meter, energy_records, user_type, langage):
    """
    Turn a visualization request into a graph, or ask for what is missing.

    The prompt is converted into a plotting configuration while the energy meter
    runs. The GPU energy is recorded. Then either the graph is rendered (nothing
    missing) or the user is asked for the missing parameters.

    Parameters
    ----------
    module_chat : ModuleChat
        Chat module holding the parsed prompt and the LLM answer.
    module_config : ModuleConfig
        Configuration module filled from the prompt.
    prompt : str
        The user prompt (in English).
    source_lang : str
        ISO code of the user's language.
    energy_meter : EnergyMeter
        pyJoules meter wrapped around the config extraction.
    energy_records : EnergyRecords
        Accumulator that receives the measured GPU energy.
    user_type : str
        Type of the connected user, forwarded to the graph service.
    langage : str
        Selected language, forwarded to the graph service.
    """
    with st.spinner("Extracting information...⌛"):
        energy_meter.start()
        module_config.prompt_to_config(module_chat.prompt)
        energy_meter.stop()
        energy_records.set_gpu(energy_meter)

    answer = module_chat.chat.llm_answer
    params = tools.convert_to_dict_config(module_config.config)
    logger.info(f"missing info --> {module_config.missings}")

    if module_config.missings:
        _handle_missing_params(
            module_chat, module_config, params, prompt, source_lang
        )
        return

    _handle_complete_visualization(
        module_chat, module_config, params, answer,
        prompt, source_lang, user_type, langage
    )
def _handle_complete_visualization(module_chat, module_config, config_params, llm_answer, prompt, source_lang, user_type, langage):
    """
    Render a graph once every required parameter is known.

    The configuration summary (preceded by the LLM answer, or by an apology for
    unknown locations) is streamed and stored in the history. The graph is then
    generated, and the export data of the newest plot is stored in
    ``data_history``.

    Parameters
    ----------
    module_chat : ModuleChat
        Chat module; its prompt is written to the logs.
    module_config : ModuleConfig
        Filled configuration (``config.not_in_shp`` and ``config.location`` are read).
    config_params : dict
        Dictionary form of the configuration; ``aggreg_type`` is added to it.
    llm_answer : str
        Text shown before the configuration summary.
    prompt : str
        The user prompt (in English).
    source_lang : str
        ISO code of the user's language.
    user_type : str
        Type of the connected user, forwarded to the graph service.
    langage : str
        Selected language, forwarded to the graph service.
    """
    cfg = module_config.config
    if len(cfg.not_in_shp) > 0:
        location_count = len(cfg.location.split(", "))
        extra = "locations" if location_count != 1 else "location"
        unknown = src_tools.enumeration(cfg.not_in_shp)
        llm_answer = f"My apologies, I don't know {unknown} but I can answer for the other {extra}."

    graph_gen = service_graph_generation.ServiceGeneratePlotlyGraph(config_params, langage, user_type)
    config_params['aggreg_type'] = graph_gen.aggreg_type

    summary = f"{llm_answer} \n" + tools.display_config(config_params) + " \n"
    if config_params['graph_type'] == 'warming stripes':
        summary += "The calculations of warming stripes is based on the period 1971-2000."

    with st.chat_message("assistant"):
        summary = tools.translate_from_en(summary, source_lang)
        st.write_stream(tools.write_like_chatGPT(summary))

    state = st.session_state
    # The config and the assistant message share the same id.
    answer_id = f"msg_{len(state.messages)}"
    state.config_params[answer_id] = config_params
    state.messages[answer_id] = {"role": "assistant", "content": summary}
    state.last_response_is_fig = True

    with st.spinner(tools.generate_waiting_for_graph_msg() + "⌛"):
        graph_gen.generate()

    plot_ids = list(state.plots_history.keys())
    if plot_ids:
        newest_plot = plot_ids[-1]
        title, data_csv = tools.add_download_feedback_button(
            prompt,
            config_params,
            state.plots_history[newest_plot],
            "in_prompt",
        )
        state.data_history[newest_plot] = (title, data_csv)

    tools.update_logs(prompt, str(module_chat.prompt), tools.from_dict_to_str(config_params))
def _handle_missing_params(module_chat, module_config, config_params, prompt, source_lang):
    """
    Ask the user for the parameters that could not be extracted from the request.

    With unknown locations (``config.not_in_shp``), the reply apologises and asks
    for other locations, plus any remaining missing details. Otherwise it thanks
    the user, mentions the climate variable when it is available, and lists the
    missing details. The reply is streamed, stored and logged, and a feedback
    button is added.

    Parameters
    ----------
    module_chat : ModuleChat
        Chat module; its prompt is written to the logs.
    module_config : ModuleConfig
        Configuration module (``missings`` and ``config.not_in_shp`` are read).
    config_params : dict
        Dictionary form of the configuration (``climate_variable`` is read).
    prompt : str
        The user prompt (in English).
    source_lang : str
        ISO code of the user's language.
    """
    missing_questions = [
        tools.get_link_question_for_missing(link)
        for link in module_config.missings
    ]
    unknown_locations = module_config.config.not_in_shp

    if len(unknown_locations) > 0:
        extra = "another location" if len(unknown_locations) == 1 else "other locations"
        request = (
            f"My apologies, I don't know {src_tools.enumeration(unknown_locations)}."
            f" Could you please ask me again for {extra}?"
        )
        # Drop the question stored at index 2 of ``link_questions`` (presumably
        # the location question, already covered by the apology above).
        # Guarded because it is not necessarily among the missing questions.
        # NOTE(review): confirm the ordering of ``link_questions``.
        location_question = list(tools.get_link_question_for_missing.link_questions.values())[2]
        if location_question in missing_questions:
            missing_questions.remove(location_question)
        if missing_questions:
            request += f" In addition, could you please precise me {src_tools.enumeration(missing_questions)}."
    else:
        request = "Thank you for your request. "
        # The variable itself may be missing; only mention it when it is set.
        climate_variable = config_params.get('climate_variable')
        if climate_variable and climate_variable.lower() in parameters.available_vars:
            request += f"I can give you a graphical view of the {climate_variable}. "
        request += f"Could you please precise me {src_tools.enumeration(missing_questions)} ?"

    # Called for its effect on module_config; the return value is not used here.
    module_config.ask_missings()

    with st.chat_message("assistant"):
        request = tools.translate_from_en(request, source_lang)
        st.write_stream(tools.write_like_chatGPT(request))
    messages = st.session_state.messages
    messages[f"msg_{len(messages)}"] = {
        "role": "assistant",
        "content": request,
    }
    st.session_state.last_response_is_fig = False
    tools.update_logs(prompt, str(module_chat.prompt), request)
    tools.add_feedback_msg_button(prompt, request)
[docs]
def process_prompt(module_chat, module_config, prompt, source_lang, energy_meter, energy_records, p, user_type, langage):
    """
    Handle one user prompt end to end.

    Steps:
    1. Record and display the prompt.
    2. Translate it to English.
    3. Run the LLM pipeline (classification, rephrasing, scope check, answer)
       under the energy meter.
    4. Dispatch to the discussion and/or visualization handler.
    5. Update the energy figures and the process log.

    Parameters
    ----------
    module_chat : ModuleChat
        Chat module that parses and answers the prompt.
    module_config : ModuleConfig
        Configuration module used for visualization requests.
    prompt : str
        Raw user prompt, in the user's language.
    source_lang : str
        ISO code of the user's language.
    energy_meter : EnergyMeter
        pyJoules meter wrapped around the LLM calls.
    energy_records : EnergyRecords
        Per-query and total energy accumulator.
    p : psutil.Process
        Process whose CPU usage is sampled for the energy estimate.
    user_type : str
        Type of the connected user.
    langage : str
        Selected language, forwarded to the visualization handler.
    """
    tools.send_user_event("prompt")
    prompt = tools.clean_prompt(prompt)

    # Start a fresh per-query record. This cpu_percent() call restarts psutil's
    # measurement window so the reading taken at the end covers this query;
    # its own value is deliberately discarded.
    energy_records.clear_query()
    p.cpu_percent()
    st.session_state.time_last_call_p = time.time()

    messages = st.session_state.messages
    messages[f"msg_{len(messages)}"] = {
        "role": "user",
        "content": prompt,
    }
    st.session_state.conv_cleared = False
    with st.chat_message("user"):
        st.markdown(prompt)

    prompt = tools.translate_to_en(prompt, source_lang)
    with st.spinner("Processing the request...⌛"):
        module_chat.create_user_prompt(prompt)
        assert module_chat.prompt is not None
        energy_meter.start()
        module_chat.prompt_classification()
        module_chat.prompt_rephrasing()
        module_chat.is_prompt_in_scope()
        module_chat.generate_text_answer()
        energy_meter.stop()
        energy_records.set_gpu(energy_meter)

    # Guard against an unset type so the membership tests cannot raise TypeError.
    type_of_request = module_chat.prompt.type or ""
    logger.debug(repr(module_chat.prompt))
    if "discussion" in type_of_request:
        handle_discussion(module_chat, prompt, source_lang)
    if "visualisation" in type_of_request or "visualization" in type_of_request:
        handle_visualization(
            module_chat, module_config, prompt, source_lang,
            energy_meter, energy_records, user_type, langage
        )

    _finalize_energy_recording(energy_records, p)

    st.session_state.history_logging = tools.update_logging_history(
        st.session_state.history_logging,
        st.session_state.process_logging.doc(),
        prompt,
    )
    st.session_state.log_container.markdown(st.session_state.history_logging)
    tools.show_energy_consumption(st.session_state.energy_consumption)


def _finalize_energy_recording(energy_records, p):
    """Record the query's CPU/RAM energy and publish the query and total figures to session state."""
    cpu_percent = p.cpu_percent()
    # NOTE(review): the elapsed time is capped at 11 s; the rationale for this
    # constant is not visible here — confirm.
    duration = min(time.time() - st.session_state.time_last_call_p, 11)
    st.session_state.time_last_call_p = time.time()
    # Subtract the 0 % CPU baseline from the estimate at the measured load.
    energy_records.set_cpu(
        energy.get_energy_cpu(cpu_percent=cpu_percent, duration=duration)
        - energy.get_energy_cpu(cpu_percent=0, duration=duration)
    )
    energy_records.set_ram(
        energy.get_energy_ram(cpu_percent=cpu_percent, duration=duration)
        - energy.get_energy_ram(cpu_percent=0, duration=duration)
    )
    st.session_state.energy_consumption = {
        "query_energy": round(energy_records.query / 1000, 2),
        "query_CO2": energy_records.get_co2(type_="query"),
        "query_water": energy_records.get_water(type_="query"),
        "total_energy": round(energy_records.total / 1000, 2),
        "total_CO2": energy_records.get_co2(type_="total"),
        "total_water": energy_records.get_water(type_="total"),
    }
[docs]
def main(tab1_options):
    """
    Entry point for Tab 1 (the chatbot page).

    Initialises session state, optionally clears the conversation, re-renders
    the history, then reads and processes the next user prompt.

    Parameters
    ----------
    tab1_options : dict
        Tab options. The keys read here are ``"langage"``, ``"clear_conv"``
        and ``"dev_mode"``.
    """
    if "module_chat" not in st.session_state:
        st.session_state.module_chat = ModuleChat(st.session_state.connected_user['user_type'])
    module_chat = st.session_state.module_chat
    langage = tab1_options["langage"]
    initialize_session_state(module_chat, langage)

    source_lang = tools.langage_to_iso(langage)
    user_type = st.session_state.connected_user['user_type']
    module_config = st.session_state.module_config
    energy_meter = st.session_state.energy_meter
    energy_records = st.session_state.energy_records
    p = st.session_state.p

    if tab1_options["clear_conv"]:
        logger.info("CLEAR CONVERSATION")
        clear_conversation(source_lang)
        st.rerun()

    rebuild_chat_history(source_lang)

    # NOTE(review): get_user_input is neither defined nor imported in this
    # module as shown — confirm where it is expected to come from.
    prompt = get_user_input(source_lang, tab1_options['dev_mode'])
    if not prompt:
        return

    if tools.chech_if_empty_prompt(prompt):
        messages = st.session_state.messages
        messages[f"msg_{len(messages)}"] = {
            "role": "user",
            "content": ""
        }
        with st.chat_message("user"):
            st.markdown("")
        answer = "Please write something ..."
        with st.chat_message("assistant"):
            st.write_stream(tools.write_like_chatGPT(answer))
        messages[f"msg_{len(messages)}"] = {
            "role": "assistant",
            "content": answer
        }
        # This reply is text, like the other non-figure answers. Without the
        # reset, the next rerun would attach the figure download/feedback
        # widgets to this message.
        st.session_state.last_response_is_fig = False
        return

    process_prompt(
        module_chat, module_config, prompt, source_lang,
        energy_meter, energy_records, p,
        user_type, langage
    )
    st.rerun()