from copy import deepcopy
import time
from typing import List, Dict, Optional, Any

import hydra
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, AIMessage, SystemMessage

from flows.message_annotators.abstract import MessageAnnotator
from flows.base_flows.abstract import AtomicFlow
from flows.datasets import GenericDemonstrationsDataset
from flows import utils
from flows.messages.flow_message import UpdateMessage_ChatMessage
from flows.utils.caching_utils import flow_run_cache
from flows.utils.general_helpers import validate_parameters

log = utils.get_pylogger(__name__)


# ToDo: Add support for demonstrations
class OpenAIChatAtomicFlow(AtomicFlow):
    """An atomic flow that wraps an OpenAI chat model and maintains the conversation state."""

    REQUIRED_KEYS_CONFIG = ["model_name", "generation_parameters"]
    REQUIRED_KEYS_KWARGS = ["system_message_prompt_template",
                            "human_message_prompt_template",
                            "query_message_prompt_template"]

    SUPPORTS_CACHING: bool = True

    api_keys: Dict[str, str]

    system_message_prompt_template: PromptTemplate
    human_message_prompt_template: PromptTemplate
    query_message_prompt_template: Optional[PromptTemplate] = None

    demonstrations: GenericDemonstrationsDataset = None
    demonstrations_response_template: PromptTemplate = None
    response_annotators: Optional[Dict[str, MessageAnnotator]] = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_keys = None

        # The flow name must not collide with one of the reserved chat roles
        assert self.flow_config["name"] not in [
            "system",
            "user",
            "assistant",
        ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"

    def set_up_flow_state(self):
        super().set_up_flow_state()
        self.flow_state["previous_messages"] = []

    @classmethod
    def _validate_parameters(cls, kwargs):
        validate_parameters(cls, kwargs)

    @classmethod
    def _set_up_prompts(cls, config):
        kwargs = {}

        kwargs["system_message_prompt_template"] = \
            hydra.utils.instantiate(config["system_message_prompt_template"], _convert_="partial")
        kwargs["query_message_prompt_template"] = \
            hydra.utils.instantiate(config["query_message_prompt_template"], _convert_="partial")
        kwargs["human_message_prompt_template"] = \
            hydra.utils.instantiate(config["human_message_prompt_template"], _convert_="partial")

        return kwargs

    # @classmethod
    # def _set_up_demonstration_templates(cls, config):
    #     kwargs = {}
    #
    #     if "demonstrations_response_template" in config:
    #         kwargs["demonstrations_response_template"] = \
    #             hydra.utils.instantiate(config['demonstrations_response_template'], _convert_="partial")
    #
    #     return kwargs

    @classmethod
    def _set_up_response_annotators(cls, config):
        # Deep-copy so that replacing the entries with instantiated annotators
        # does not mutate the passed config
        response_annotators = deepcopy(config.get("response_annotators", {}))
        for key, annotator_config in response_annotators.items():
            response_annotators[key] = hydra.utils.instantiate(annotator_config, _convert_="partial")
        return {"response_annotators": response_annotators}

    @classmethod
    def instantiate_from_config(cls, config):
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # # ~~~ Set up demonstration templates ~~~
        # kwargs.update(cls._set_up_demonstration_templates(flow_config))

        # ~~~ Set up response annotators ~~~
        kwargs.update(cls._set_up_response_annotators(flow_config))

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _is_conversation_initialized(self):
        return len(self.flow_state["previous_messages"]) > 0

    def get_expected_inputs(self, data: Optional[Dict[str, Any]] = None):
        """Returns the expected inputs for the flow given the current state and, optionally, the input data."""
        if self._is_conversation_initialized():
            # In a follow-up turn, only the new query is needed
            return ["query"]
        else:
            return self.flow_config["expected_inputs"]
    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        # Fill the prompt template with the corresponding values from the input data
        template_kwargs = {input_variable: input_data[input_variable]
                           for input_variable in prompt_template.input_variables}
        return prompt_template.format(**template_kwargs)

    # def _get_demonstration_query_message_content(self, sample_data: Dict):
    #     input_variables = self.query_message_prompt_template.input_variables
    #     return self.query_message_prompt_template.format(**{k: sample_data[k] for k in input_variables}), []
    #
    # def _get_demonstration_response_message_content(self, sample_data: Dict):
    #     input_variables = self.demonstrations_response_template.input_variables
    #     return self.demonstrations_response_template.format(**{k: sample_data[k] for k in input_variables}), []

    def _get_annotator_with_key(self, key: str):
        for _, ra in self.response_annotators.items():
            if ra.key == key:
                return ra
        return None

    def _response_parsing(self, response: str, expected_outputs: List[str]):
        # Run every annotator whose key is among the expected outputs
        target_annotators = [ra for _, ra in self.response_annotators.items() if ra.key in expected_outputs]

        parsed_outputs = {}
        for ra in target_annotators:
            parsed_outputs.update(ra(response))

        if "raw_response" in expected_outputs:
            parsed_outputs["raw_response"] = response
        else:
            log.warning("The raw response is not logged because it was not requested as an expected output.")

        if len(parsed_outputs) == 0:
            raise Exception(f"The output dictionary is empty. "
                            f"None of the expected outputs: `{str(expected_outputs)}` were found.")

        return parsed_outputs
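    # The parsing above relies on an annotator contract that is only implied by this
    # file: each response annotator exposes a string attribute `key` naming the output
    # it produces, and is callable on the raw response string, returning a dict of
    # parsed values. A minimal sketch of such an annotator follows (a hypothetical
    # example for illustration, not a class shipped in flows.message_annotators):
    #
    #     class AnswerAnnotator(MessageAnnotator):
    #         key = "answer"
    #
    #         def __call__(self, response: str, **kwargs) -> Dict[str, str]:
    #             # Keep everything after the last "Answer:" marker
    #             return {self.key: response.split("Answer:")[-1].strip()}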
    # def _add_demonstrations(self):
    #     if self.demonstrations is not None:
    #         for example in self.demonstrations:
    #             query, parents = self._get_demonstration_query_message_content(example)
    #             response, parents = self._get_demonstration_response_message_content(example)
    #
    #             self._log_chat_message(content=query,
    #                                    role=self.user_name,
    #                                    parent_message_ids=parents)
    #
    #             self._log_chat_message(content=response,
    #                                    role=self.assistant_name,
    #                                    parent_message_ids=parents)

    def _state_update_add_chat_message(self, role: str, content: str) -> None:
        # Add the message to the previous messages list
        if role == self.flow_config["system_name"]:
            self.flow_state["previous_messages"].append(SystemMessage(content=content))
        elif role == self.flow_config["user_name"]:
            self.flow_state["previous_messages"].append(HumanMessage(content=content))
        elif role == self.flow_config["assistant_name"]:
            self.flow_state["previous_messages"].append(AIMessage(content=content))
        else:
            raise Exception(f"Invalid role: `{role}`.\n"
                            f"Role should be one of: "
                            f"`{self.flow_config['system_name']}`, "
                            f"`{self.flow_config['user_name']}`, "
                            f"`{self.flow_config['assistant_name']}`")

        # Log the update to the flow messages list
        chat_message = UpdateMessage_ChatMessage(
            created_by=self.flow_config["name"],
            updated_flow=self.flow_config["name"],
            role=role,
            content=content,
        )
        self._log_message(chat_message)

    def _call(self):
        api_key = self.api_keys["openai"]
        backend = ChatOpenAI(
            model_name=self.flow_config["model_name"],
            openai_api_key=api_key,
            **self.flow_config["generation_parameters"],
        )

        messages = self.flow_state["previous_messages"]

        # Retry the API call up to `n_api_retries` times before giving up
        _success = False
        attempts = 1
        error = None
        response = None
        while attempts <= self.flow_config["n_api_retries"]:
            try:
                response = backend(messages).content
                _success = True
                break
            except Exception as e:
                # Never log the API key itself
                log.error(
                    f"Error {attempts} in calling backend: {e}. "
                    f"Retrying in {self.flow_config['wait_time_between_retries']} seconds..."
                )
                # log.error(
                #     f"The API call raised an exception with the following arguments: "
                #     f"\n{self.flow_state['history'].to_string()}"
                # )  # ToDo: Make this message more user-friendly
                attempts += 1
                time.sleep(self.flow_config["wait_time_between_retries"])
                error = e

        if not _success:
            raise error

        return response

    def _initialize_conversation(self, input_data: Dict[str, Any]):
        # ~~~ Add the system message ~~~
        system_message_content = self._get_message(self.system_message_prompt_template, input_data)
        self._state_update_add_chat_message(content=system_message_content,
                                            role=self.flow_config["system_name"])

        # # ~~~ Add the demonstration query-response tuples (if any) ~~~
        # self._add_demonstrations()
        # self._update_state(update_data={"conversation_initialized": True})

    def _process_input(self, input_data: Dict[str, Any]):
        if self._is_conversation_initialized():
            # Construct the message using the human message prompt template
            user_message_content = self._get_message(self.human_message_prompt_template, input_data)
        else:
            # Initialize the conversation (add the system message, and potentially the demonstrations)
            self._initialize_conversation(input_data)

            # Construct the message using the query message prompt template
            user_message_content = self._get_message(self.query_message_prompt_template, input_data)

        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)

    @flow_run_cache()
    def run(self,
            input_data: Dict[str, Any],
            private_keys: Optional[List[str]] = [],
            keys_to_ignore_for_hash: Optional[List[str]] = []) -> Dict[str, Any]:
        """Runs a single exchange: appends the user message to the chat state,
        queries the backend, logs the assistant message, and parses the response."""
        # Consume the API keys so that they are not kept in the input data
        self.api_keys = input_data["api_keys"]
        del input_data["api_keys"]

        # ~~~ Process input ~~~
        self._process_input(input_data)

        # ~~~ Call ~~~
        response = self._call()
        self._state_update_add_chat_message(
            role=self.flow_config["assistant_name"],
            content=response
        )

        # ~~~ Response parsing ~~~
        output_data = self._response_parsing(
            response=response,
            expected_outputs=input_data["expected_outputs"]
        )
        # self._state_update_dict(update_data=output_data)  # ToDo: Is this necessary? When?

        return output_data
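
# The sketch below illustrates how this flow is meant to be driven end to end.
# It is a minimal, hypothetical example: the config values, the
# `langchain.PromptTemplate` targets, and the placeholder API key are
# assumptions made for illustration, not part of this module; the exact config
# schema is dictated by the surrounding Flows library (AtomicFlow, its default
# configs, and the response annotator classes).
#
#     config = {
#         "name": "simple_qa_flow",
#         "description": "Answers a single question.",
#         "model_name": "gpt-3.5-turbo",
#         "generation_parameters": {"temperature": 0.3, "max_tokens": 256},
#         "system_name": "system",
#         "user_name": "user",
#         "assistant_name": "assistant",
#         "expected_inputs": ["question"],
#         "n_api_retries": 3,
#         "wait_time_between_retries": 10,
#         "system_message_prompt_template": {
#             "_target_": "langchain.PromptTemplate",
#             "template": "You are a helpful assistant.",
#             "input_variables": [],
#         },
#         "query_message_prompt_template": {
#             "_target_": "langchain.PromptTemplate",
#             "template": "Answer the following question: {question}",
#             "input_variables": ["question"],
#         },
#         "human_message_prompt_template": {
#             "_target_": "langchain.PromptTemplate",
#             "template": "{query}",
#             "input_variables": ["query"],
#         },
#     }
#
#     flow = OpenAIChatAtomicFlow.instantiate_from_config(config)
#     output = flow.run(input_data={
#         "api_keys": {"openai": "sk-..."},
#         "question": "What is an atomic flow?",
#         "expected_outputs": ["raw_response"],
#     })
#     print(output["raw_response"])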