import time
from copy import deepcopy
from typing import Any, Dict, Optional

import hydra

from aiflows.backends.llm_lite import LiteLLMBackend
from aiflows.base_flows import AtomicFlow
from aiflows.messages import FlowMessage
from aiflows.messages.flow_message import UpdateMessage_ChatMessage
from aiflows.prompt_template import JinjaPrompt
from aiflows.utils import logging

log = logging.get_logger(__name__)


class ChatAtomicFlow(AtomicFlow):
""" This class implements a ChatAtomicFlow. It is a flow that uses a LLM via an API to generate textuals responses to textual inputs. |
|
It employs litellm as a backend to query the LLM via an API. See litellm's supported models and APIs here: https://docs.litellm.ai/docs/providers |
|
|
|
*Configuration Parameters*: |
|
|
|
- `name` (str): The name of the flow. This name is used to identify the flow in the logs. Default: "ChatAtomicFlow" |
|
- `description` (str): A description of the flow. This description is used to generate the help message of the flow. |
|
Default: "Flow which uses as tool an LLM though an API" |
|
- `enable_cache` (bool): Whether to enable cache for this flow. Default: True |
|
- `n_api_retries` (int): The number of times to retry the API call in case of failure. Default: 6 |
|
- `wait_time_between_retries` (int): The number of seconds to wait between each API call when the call fails. Default: 20 |
|
- `system_name` (str): The name of the system (roles of LLM). Default: "system" |
|
- `user_name` (str): The name of the user (roles of LLM). Default: "user" |
|
- `assistant_name` (str): The name of the assistant (roles of LLM). Default: "assistant" |
|
- `backend` Dict[str,Any]: The backend of the flow. Used to call models via an API. |
|
See litellm's supported models and APIs here: https://docs.litellm.ai/docs/providers. |
|
The default parameters of the backend are all defined at aiflows.backends.llm_lite.LiteLLMBackend |
|
(also see the defaul parameters of litellm's completion parameters: https://docs.litellm.ai/docs/completion/input#input-params-1). |
|
Except for the following parameters who are overwritten by the ChatAtomicFlow in ChatAtomicFlow.yaml: |
|
- `model_name` (Union[Dict[str,str],str]): The name of the model to use. |
|
When using multiple API providers, the model_name can be a dictionary of the form |
|
{"provider_name": "model_name"}. E.g. {"openai": "gpt-3.5-turbo", "azure": "azure/gpt-3.5-turbo"} |
|
Default: "gpt-3.5-turbo" (the name needs to follow the name of the model in litellm https://docs.litellm.ai/docs/providers). |
|
- `n` (int) : The number of answers to generate. Default: 1 |
|
- `max_tokens` (int): The maximum number of tokens to generate. Default: 2000 |
|
- `temperature` float: The temperature of the generation. Default: 0.3 |
|
- `top_p` float: An alternative to sampling with temperature. It instructs the model to consider the results of |
|
the tokens with top_p probability. Default: 0.2 |
|
- `frequency_penalty` (number): It is used to penalize new tokens based on their frequency in the text so far. Default: 0.0 |
|
- `presence_penalty` (number): It is used to penalize new tokens based on their existence in the text so far. Default: 0.0 |
|
- `stream` (bool): Whether to stream the response or not. Default: True |
|
- `system_message_prompt_template` (Dict[str,Any]): The template of the system message. It is used to generate the system message. |
|
By default its of type aiflows.prompt_template.JinjaPrompt. |
|
None of the parameters of the prompt are defined by default and therefore need to be defined if one wants to use the system prompt. |
|
Default parameters are defined in aiflows.prompt_template.jinja2_prompts.JinjaPrompt. |
|
- `init_human_message_prompt_template` (Dict[str,Any]): The prompt template of the human/user message used to initialize the conversation |
|
(first time in). It is used to generate the human message. It's passed as the user message to the LLM. |
|
By default its of type aiflows.prompt_template.JinjaPrompt. None of the parameters of the prompt are defined by default and therefore need to be defined if one |
|
wants to use the init_human_message_prompt_template. Default parameters are defined in aiflows.prompt_template.jinja2_prompts.JinjaPrompt. |
|
- `human_message_prompt_template` (Dict[str,Any]): The prompt template of the human/user message (message used everytime the except the first time in). |
|
It's passed as the user message to the LLM. By default its of type aiflows.prompt_template.JinjaPrompt and has the following parameters: |
|
- `template` (str): The template of the human message. Default: see ChatAtomicFlow.yaml for the default value. |
|
- `input_variables` (List[str]): The input variables of the human message prompt template. Default: ["query"] |
|
- `previous_messages` (Dict[str,Any]): Defines which previous messages to include in the input of the LLM. Note that if `first_k`and `last_k` are both none, |
|
all the messages of the flows's history are added to the input of the LLM. Default: |
|
- `first_k` (int): If defined, adds the first_k earliest messages of the flow's chat history to the input of the LLM. Default: None |
|
- `last_k` (int): If defined, adds the last_k latest messages of the flow's chat history to the input of the LLM. Default: None |
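
    The following sketch shows an override of some of these parameters as a Python dict
    (hypothetical values; every parameter not set here falls back to the defaults in
    ChatAtomicFlow.yaml):

    .. code-block:: python

        overrides = {
            "name": "SimpleQA_Flow",
            "backend": {"model_name": "gpt-3.5-turbo", "n": 1, "temperature": 0.3},
            "system_message_prompt_template": {
                "template": "You are a helpful assistant.",
                "input_variables": [],
            },
            "init_human_message_prompt_template": {
                "template": "Answer the following question: {{query}}",
                "input_variables": ["query"],
            },
        }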

    *Input Interface Non-Initialized (expected input the first time in flow)*:

    - `query` (str): The query given to the flow (e.g. {"query": "What's the capital of Switzerland?"})

    *Input Interface Initialized (expected input after the first time in flow)*:

    - `query` (str): The query given to the flow (e.g. {"query": "Are you sure of your answer?"})

    *Output Interface*:

    - `api_output` (str): The output of the API call. It is the response of the LLM to the input (e.g. {"api_output": "The capital of Switzerland is Bern"})

    :param system_message_prompt_template: The template of the system message. It is used to generate the system message.
    :type system_message_prompt_template: JinjaPrompt
    :param human_message_prompt_template: The template of the human message. It is used to generate the human message.
    :type human_message_prompt_template: JinjaPrompt
    :param init_human_message_prompt_template: The template of the human message that is used to initialize the conversation (first time in). It is used to generate the human message.
    :type init_human_message_prompt_template: Optional[JinjaPrompt]
    :param backend: The backend of the flow. It is an LLM that is queried via an API. See litellm's supported models and APIs here: https://docs.litellm.ai/docs/providers
    :type backend: LiteLLMBackend
    :param \**kwargs: Additional arguments to pass to the flow. See :class:`aiflows.base_flows.AtomicFlow` for more details.
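
    Example (a minimal sketch, not canonical usage; it assumes the base class provides
    `instantiate_from_default_config` and reuses the hypothetical `overrides` dict from the
    configuration section above):

    .. code-block:: python

        flow = ChatAtomicFlow.instantiate_from_default_config(overrides)
        print(flow.get_interface_description())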
    """

    REQUIRED_KEYS_CONFIG = ["backend"]

    SUPPORTS_CACHING: bool = True

    system_message_prompt_template: JinjaPrompt
    human_message_prompt_template: JinjaPrompt

    backend: LiteLLMBackend
    init_human_message_prompt_template: Optional[JinjaPrompt] = None

    def __init__(self,
                 system_message_prompt_template,
                 human_message_prompt_template,
                 init_human_message_prompt_template,
                 backend,
                 **kwargs):
        super().__init__(**kwargs)
        self.system_message_prompt_template = system_message_prompt_template
        self.human_message_prompt_template = human_message_prompt_template
        self.init_human_message_prompt_template = init_human_message_prompt_template
        self.backend = backend
        assert self.flow_config["name"] not in [
            "system",
            "user",
            "assistant",
        ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"

    def set_up_flow_state(self):
        """ This method sets up the state of the flow and clears the previous messages."""
        super().set_up_flow_state()
        self.flow_state["previous_messages"] = []

    @classmethod
    def _set_up_prompts(cls, config):
        """ This method sets up the prompts of the flow. It instantiates the prompt templates.

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated prompts of the flow.
        :rtype: Dict[str, Any]
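
        Each prompt-template entry in the config is expected to be hydra-instantiable,
        i.e. a dict whose `_target_` key names the class to build (a sketch with
        hypothetical values):

        .. code-block:: python

            config["system_message_prompt_template"] = {
                "_target_": "aiflows.prompt_template.JinjaPrompt",
                "template": "You are a helpful assistant.",
                "input_variables": [],
            }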
        """
        kwargs = {}

        kwargs["system_message_prompt_template"] = \
            hydra.utils.instantiate(config['system_message_prompt_template'], _convert_="partial")
        kwargs["init_human_message_prompt_template"] = \
            hydra.utils.instantiate(config['init_human_message_prompt_template'], _convert_="partial")
        kwargs["human_message_prompt_template"] = \
            hydra.utils.instantiate(config['human_message_prompt_template'], _convert_="partial")

        return kwargs

    @classmethod
    def _set_up_backend(cls, config):
        """ This method sets up the backend of the flow. It instantiates the backend.

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated backend of the flow.
        :rtype: Dict[str, Any]
        """
        kwargs = {}

        kwargs["backend"] = \
            hydra.utils.instantiate(config['backend'], _convert_="partial")

        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """ This method instantiates the flow from a configuration dictionary.

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: ChatAtomicFlow
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update(cls._set_up_backend(flow_config))

        return cls(**kwargs)

    def _is_conversation_initialized(self):
        """ This method checks whether the conversation is initialized.

        :return: True if the conversation is initialized, False otherwise.
        :rtype: bool
        """
        return len(self.flow_state["previous_messages"]) > 0

    def get_interface_description(self):
        """ This method returns the description of the flow's input and output interface.

        :return: The description of the flow's interface.
        :rtype: Dict[str, Any]
        """
        if self._is_conversation_initialized():
            return {"input": self.flow_config["input_interface_initialized"],
                    "output": self.flow_config["output_interface"]}
        else:
            return {"input": self.flow_config["input_interface_non_initialized"],
                    "output": self.flow_config["output_interface"]}

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """ This method generates a message using a prompt template and input data, which should contain the input variables of the prompt template.

        :param prompt_template: The prompt template.
        :type prompt_template: JinjaPrompt
        :param input_data: The input data of the prompt template.
        :type input_data: Dict[str, Any]
        :return: The generated message.
        :rtype: str
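
        Example (illustrative only; hypothetical template):

        .. code-block:: python

            # With prompt_template.input_variables == ["query"] and
            # prompt_template.template == "Question: {{query}}", this returns
            # "Question: What's 2+2?"
            ChatAtomicFlow._get_message(prompt_template, {"query": "What's 2+2?"})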
        """
        template_kwargs = {}
        for input_variable in prompt_template.input_variables:
            template_kwargs[input_variable] = input_data[input_variable]
        msg_content = prompt_template.format(**template_kwargs)
        return msg_content

    def _add_demonstrations(self, input_data: Dict[str, Any]):
        """ This method adds demonstrations to the flow (if any). The demonstrations should be passed in from a DemonstrationFlow.

        :param input_data: The input data of the flow.
        :type input_data: Dict[str, Any]
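
        Expected format of the demonstrations (inferred from the loop below;
        values are hypothetical):

        .. code-block:: python

            input_data["demonstrations"] = [
                {"query": "What's 2+2?", "response": "4"},
            ]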
        """
        for demonstration in input_data.get("demonstrations", []):
            query = demonstration["query"]
            response = demonstration["response"]

            self._state_update_add_chat_message(content=query,
                                                role=self.flow_config["user_name"])

            self._state_update_add_chat_message(content=response,
                                                role=self.flow_config["assistant_name"])

    def _state_update_add_chat_message(self,
                                       role: str,
                                       content: str) -> None:
        """ This method adds a message to the flow's state.

        :param role: The role of the message (e.g. "user", "assistant", "system").
        :type role: str
        :param content: The content of the message.
        :type content: str
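
        Example (illustrative; assumes the default role names):

        .. code-block:: python

            # Appends {"role": "user", "content": "Hi"} to
            # flow_state["previous_messages"] and logs an update message.
            flow._state_update_add_chat_message(role="user", content="Hi")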
        """
        acceptable_roles = [self.flow_config["system_name"],
                            self.flow_config["user_name"],
                            self.flow_config["assistant_name"]]
        if role in acceptable_roles:
            self.flow_state["previous_messages"].append({"role": role, "content": content})
        else:
            raise Exception(f"Invalid role: `{role}`.\n"
                            f"Role should be one of: `{acceptable_roles}`")

        chat_message = UpdateMessage_ChatMessage(
            created_by=self.flow_config["name"],
            updated_flow=self.flow_config["name"],
            role=role,
            content=content,
        )
        self._log_message(chat_message)

    def _get_previous_messages(self):
        """ This method returns the previous messages of the flow. If `first_k` is set, it returns the first k messages; if `last_k` is set, it returns the last k messages.
        If both are set, it returns the first k messages followed by the last k messages. If neither is set, it returns all the messages.

        :return: The previous messages of the flow.
        :rtype: List[Dict[str, Any]]
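
        Example (plain-list illustration of the slicing below):

        .. code-block:: python

            msgs = ["m0", "m1", "m2", "m3"]
            msgs[:1] + msgs[-2:]  # first_k=1, last_k=2 -> ["m0", "m2", "m3"]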
        """
        all_messages = self.flow_state["previous_messages"]
        first_k = self.flow_config["previous_messages"]["first_k"]
        last_k = self.flow_config["previous_messages"]["last_k"]

        if not first_k and not last_k:
            return all_messages
        elif first_k and last_k:
            return all_messages[:first_k] + all_messages[-last_k:]
        elif first_k:
            return all_messages[:first_k]
        return all_messages[-last_k:]

    def _call(self):
        """ This method calls the backend of the flow (i.e., queries the LLM) with the flow's previous messages.

        :return: The output of the backend.
        :rtype: Any
        """
        messages = self._get_previous_messages()
        _success = False
        attempts = 1
        error = None
        response = None
        while attempts <= self.flow_config['n_api_retries']:
            try:
                response = self.backend(messages=messages, mock_response=False)
                response = [answer["content"] for answer in response]
                _success = True
                break
            except Exception as e:
                log.error(f"Error {attempts} in calling backend: {e}.")
                log.error(f"Retrying in {self.flow_config['wait_time_between_retries']} seconds...")
                attempts += 1
                time.sleep(self.flow_config['wait_time_between_retries'])
                error = e

        if not _success:
            raise error

        return response

    def _initialize_conversation(self, input_data: Dict[str, Any]):
        """ This method initializes the conversation (occurs the first time in). It adds the system message and, if provided, the demonstrations.

        :param input_data: The input data of the flow.
        :type input_data: Dict[str, Any]
        """
        system_message_content = self._get_message(self.system_message_prompt_template, input_data)

        self._state_update_add_chat_message(content=system_message_content,
                                            role=self.flow_config["system_name"])

        self._add_demonstrations(input_data)

    def _process_input(self, input_data: Dict[str, Any]):
        """ This method processes the input of the flow. It adds the human message to the flow's state. If the conversation is not initialized, it also initializes it
        (adding the system message and, if provided, the demonstrations).

        :param input_data: The input data of the flow.
        :type input_data: Dict[str, Any]
        """
        if self._is_conversation_initialized():
            user_message_content = self._get_message(self.human_message_prompt_template, input_data)
        else:
            self._initialize_conversation(input_data)
            if getattr(self, "init_human_message_prompt_template", None) is not None:
                user_message_content = self._get_message(self.init_human_message_prompt_template, input_data)
            else:
                user_message_content = self._get_message(self.human_message_prompt_template, input_data)

        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)

    def query_llm(self, input_data: Dict[str, Any]):
        """ This method queries the LLM: it processes the input, calls the backend, and returns the response.

        :param input_data: The input data of the flow.
        :type input_data: Dict[str, Any]
        :return: The response of the LLM to the input. A single string when one answer was requested (n=1), otherwise a list of strings.
        :rtype: Union[str, List[str]]
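
        Example (illustrative; assumes a configured flow instance `flow` and n=1
        in the backend config):

        .. code-block:: python

            answer = flow.query_llm({"query": "What's the capital of Switzerland?"})
            # -> e.g. "The capital of Switzerland is Bern"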
        """
        self._process_input(input_data)

        response = self._call()

        for answer in response:
            self._state_update_add_chat_message(
                role=self.flow_config["assistant_name"],
                content=answer
            )
        # Unwrap single-answer responses to a plain string; keep lists otherwise.
        response = response if len(response) > 1 or len(response) == 0 else response[0]
        return response

    def run(self, input_message: FlowMessage):
        """ This method runs the flow. It processes the input, calls the backend, and sends the response back as a reply message.

        :param input_message: The input message of the flow.
        :type input_message: aiflows.messages.FlowMessage
        """
        input_data = input_message.data

        response = self.query_llm(input_data=input_data)

        reply_message = self.package_output_message(
            input_message,
            response={"api_output": response}
        )

        self.send_message(reply_message)