# Copyright 2024 Mews
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
from typing import Union
from diva.chat.chat_object import Chat
from diva.chat.prompt_object import Prompt
from diva.chat.services import ChatbotScopes
from diva.chat.services import Disclaimer
from diva.chat.services import GenChatbotAnswer
from diva.chat.services import IsMemoryNeeded
from diva.chat.services import PromptClassification
from diva.chat.services import PromptRephrasing
from diva.logging.logger import crossing_point
class ModuleChat:
    """
    This module manages everything related to the chat between the user and chatbot, including
    Large Language Model (LLM) interactions.

    The module owns a single ``Chat`` instance (prompt history, current prompt, LLM answers)
    and delegates each step of the answering pipeline to a dedicated strategy object:
    prompt classification, scope verification, memory detection, prompt rephrasing,
    answer generation and disclaimers. The config creation strategy is not built here:
    it must be injected through :meth:`set_config_creation_strategy`.
    """

    def __init__(self):
        self.__chat = Chat()
        self.__memory_needed = False
        # --- strategy objects, one per pipeline step
        self.__classification_strategy = PromptClassification()
        self.__chatbot_scope_strategy = ChatbotScopes()
        self.__memory_strategy = IsMemoryNeeded()
        self.__rephrasing_strategy = PromptRephrasing()
        self.__gen_chatbot_answer_strategy = GenChatbotAnswer()
        self.__disclaimer_strategy = Disclaimer()
        # --- injected later via set_config_creation_strategy(); None until then
        self.__config_creation_strategy = None

    # -
    def __previous_prompt(self) -> Union['Prompt', None]:
        """
        Returns the previous user prompt, or None when the current prompt is the first one.

        Index -1 of the history is the current prompt (appended to the history when the
        Prompt instance is created), hence index -2 for the previous one.
        """
        if len(self.prompt_history) > 1:
            return self.prompt_history[-2]
        return None

    # -
    def __clear_config_creation_memory(self, hard: Union[bool, None] = None):
        """
        Clears the local memory of the config creation strategy, if one has been injected.

        Parameters
        ----------
        hard: Union[bool, None], optional, default to None.
            forwarded to ``clear_local_memory(hard=...)`` when not None; when None the
            strategy is called without arguments so its own default applies.
        """
        # Guard added: __init__ leaves the strategy as None until
        # set_config_creation_strategy() is called; calling clear_local_memory()
        # on None raised an AttributeError.
        if self.__config_creation_strategy is None:
            return
        if hard is None:
            self.__config_creation_strategy.clear_local_memory()
        else:
            self.__config_creation_strategy.clear_local_memory(hard=hard)

    # -
    def clear_history(self):
        """ Clears the chat history (and the config creation strategy local memory)."""
        self.__chat.clear_history()
        self.__clear_config_creation_memory()

    # -
    @crossing_point("Function that classifies prompt intention as either 'discussion' or 'visualisation'")
    def prompt_classification(self) -> str:
        """
        Classifies the prompt (intention) as either "discussion" or "visualisation". If "discussion", the chatbot
        response is only text. If "visualisation", the chatbot response is text + graph.

        Returns
        -------
        str
            either "discussion" or "visualisation".
        """
        answer = self.__classification_strategy(
            user_prompt=str(self.prompt),
            scope_strategy=self.__chatbot_scope_strategy,
            demand_of_info=self.__chat.demand_of_info,
            config_creation_strategy=self.__config_creation_strategy
        )
        self.prompt.set_type(answer)
        # memory detection is chained here so the flag is up to date before
        # rephrasing / answer generation take place
        self.is_memory_needed(
            last_prompt=self.__previous_prompt(),
            prompt=self.prompt,
        )
        return answer

    # -
    @crossing_point("Function that verifies if the prompt corresponds to a task that is in the scope of the chatbot")
    def is_prompt_in_scope(self):
        """
        Verifies if the prompt corresponds to a task that is in the scope of the chatbot. If not in scope, the chatbot
        may not answer.
        """
        self.__chatbot_scope_strategy(prompt=self.prompt, last_prompt=self.__previous_prompt())

    # -
    @crossing_point("Function that determines if memory (of the previous answer) is needed")
    def is_memory_needed(self, last_prompt: Union['Prompt', None] = None, prompt: Union['Prompt', None] = None):
        """
        Determines if memory (of the previous answer or of the previous prompt) is needed to answer the current prompt.

        The result is stored on the module (see :attr:`memory_needed`) and on the prompt itself.

        Parameters
        ----------
        last_prompt: Union['Prompt', None], optional, default to None.
            the Prompt instance containing the last user prompt. If None, it retrieves the Prompt instance directly
            from the chat history
        prompt: Union['Prompt', None], optional, default to None.
            the Prompt instance, containing the current user prompt. If None, it retrieves the Prompt instance
            directly from the chat history.
        """
        if last_prompt is None:
            last_prompt = self.__previous_prompt()
        if prompt is None:
            prompt = self.prompt
        self.__memory_needed = self.__memory_strategy(
            last_prompt=last_prompt,
            prompt=prompt,
            config_creation_strategy=self.__config_creation_strategy,
            last_answer=self.chat.llm_answer
        )
        prompt.set_memory_needed(self.__memory_needed)

    # -
    @crossing_point("Function that rephrases the user prompt to remove typos, and to make grammatically correct sentence")
    def prompt_rephrasing(self):
        """ Rephrases the user prompt to remove typos, and to make grammatically correct sentences.
        Can also be used to combine the current prompt with previous answer/prompt to mimic memory. """
        self.__rephrasing_strategy(
            last_prompt=self.__previous_prompt(),
            prompt=self.prompt,
            memory_needed=self.__memory_needed,
            last_answer=self.chat.llm_answer
        )
        if self.prompt.rephrased != self.prompt.original:
            # the prompt actually changed: soft-reset the config creation memory
            self.__clear_config_creation_memory(hard=False)
        self.is_prompt_in_scope()

    # -
    def add_disclaimer(self, types: list[str]):
        """ Adds warning about frequency of data update and differences between climate predictions and
        weather forecast.

        Parameters
        ----------
        types: list[str]
            the disclaimer types to add, forwarded to the disclaimer strategy.
        """
        disclaimer = self.__disclaimer_strategy.add_disclaimer(types)
        self.chat.add_llm_answer(disclaimer)

    # -
    @crossing_point("Function that generates a text answer in response to the user prompt")
    def generate_text_answer(self):
        """
        Generates a text answer in response to the user prompt.
        The answer may vary depending on whether the prompt asks something within the chatbot scopes, depending on
        the given context, toxicity of the prompt, type of prompt (visualisation or discussion)...
        """
        llm_answer = self.__gen_chatbot_answer_strategy(
            prompt=self.prompt,
            memory_needed=self.__memory_needed)
        self.chat.set_llm_answer(llm_answer)

    # -
    def create_user_prompt(self, prompt: str):
        """
        Creates the Prompt instance containing the user prompt. The text string of the user prompt should not be empty.
        The Prompt instance is stored in a Chat instance, itself stored in a ModuleChat instance.

        Parameters
        ----------
        prompt: str
            the text written by the user.
        """
        # isinstance instead of `type(...) is str` (PEP 8): also accepts str subclasses
        if isinstance(prompt, str) and len(prompt) > 0:
            self.chat.set_current_prompt(Prompt(prompt))
            self.chat.prompt = self.chat.current_prompt
            self.__clear_config_creation_memory()
        else:
            print(
                "WARNING: new prompt is either not of type str or is of type str with null length. Current prompt unchanged."
            )

    # -
    # setters
    def set_demand_of_info(self, text: str):
        """
        Setter. Demand of info is the demand made by the chatbot to the user when some information in the prompt are
        missing.

        Parameters
        ----------
        text: str
            the demand text; non-str values are converted with a warning.
        """
        if isinstance(text, str):
            self.chat.set_demand_of_info(text)
        else:
            self.chat.set_demand_of_info(f"{text}")
            print(
                "Warning: Demand of info set by converting value type to str. The result may be unwanted."
            )

    # -
    def set_config_creation_strategy(self, strategy):
        """ Setter. Injects the config creation strategy used by the classification, memory and
        rephrasing steps (it is None until this method is called)."""
        self.__config_creation_strategy = strategy

    # -
    @property
    def memory_needed(self) -> bool:
        # last result of is_memory_needed(); False until it has run
        return self.__memory_needed

    @property
    def chat(self) -> Chat:
        # the Chat instance owned by this module
        return self.__chat

    @property
    def prompt(self) -> Prompt:
        # current user prompt, held by the Chat instance
        return self.chat.current_prompt

    @property
    def prompt_history(self) -> list[Prompt]:
        # full prompt history (the current prompt is the last element)
        return self.chat.prompt_history

    @property
    def chatbot_scope_strategy(self) -> 'ChatbotScopes':
        # exposed so callers can reuse the scope strategy (e.g. prompt_classification)
        return self.__chatbot_scope_strategy