import pprint
import time
from copy import deepcopy
from typing import Any, Dict, List, Optional

import colorama
import hydra
import langchain
from langchain import PromptTemplate
from langchain.schema import AIMessage, HumanMessage, SystemMessage

from flows import utils
from flows.base_flows.abstract import AtomicFlow
from flows.datasets import GenericDemonstrationsDataset
from flows.history import FlowHistory
from flows.message_annotators.abstract import MessageAnnotator
from flows.messages.chat_message import ChatMessage
from flows.utils.caching_utils import flow_run_cache

log = utils.get_pylogger(__name__)


class OpenAIChatAtomicFlow(AtomicFlow):
    """Atomic flow that holds a chat conversation with an OpenAI chat model.

    On the first run the conversation is seeded with a system message and the
    optional demonstration query/response pairs. Every run then appends a user
    message, calls the model through ``langchain.chat_models.ChatOpenAI`` and
    parses the reply into the requested output keys, using the configured
    response annotators when any of them match.
    """

    model_name: str
    generation_parameters: Dict

    system_message_prompt_template: PromptTemplate
    human_message_prompt_template: PromptTemplate

    # Values used as `message_creator` for the three chat roles.
    system_name: str = "system"
    user_name: str = "user"
    assistant_name: str = "assistant"

    # `_call` makes up to `n_api_retries` attempts in total (at least one),
    # sleeping `wait_time_between_retries` seconds between failed attempts.
    n_api_retries: int = 6
    wait_time_between_retries: int = 20

    query_message_prompt_template: Optional[PromptTemplate] = None
    demonstrations: Optional[GenericDemonstrationsDataset] = None
    demonstrations_response_template: Optional[PromptTemplate] = None
    # NOTE(review): a mutable class-level default is shared between instances unless the
    # base class copies field defaults; `instantiate_from_config` always passes a fresh dict.
    response_annotators: Optional[Dict[str, MessageAnnotator]] = {}

    def __init__(self, **kwargs):
        self._validate_parameters(kwargs)
        super().__init__(**kwargs)

        # The role names are used as message creators, so a flow with one of
        # those names would presumably be ambiguous in the history.
        assert self.flow_config["name"] not in [
            "system",
            "user",
            "assistant",
        ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"

    def set_up_flow_state(self):
        """Reset the flow state and mark the conversation as not yet initialized."""
        super().set_up_flow_state()
        self.flow_state["conversation_initialized"] = False

    @classmethod
    def _validate_parameters(cls, kwargs):
        """Check that the constructor kwargs contain everything this flow needs.

        Raises:
            KeyError: if `model_name` / `generation_parameters` are missing from
                `kwargs["flow_config"]`, or a prompt template was not passed.
        """
        super()._validate_parameters(kwargs)

        # ~~~ Model generation ~~~
        flow_config = kwargs["flow_config"]
        for key in ("model_name", "generation_parameters"):
            if key not in flow_config:
                raise KeyError(f"{key} not specified in the flow_config.")

        # ~~~ Prompting ~~~
        for key in (
            "system_message_prompt_template",
            "query_message_prompt_template",
            "human_message_prompt_template",
        ):
            if key not in kwargs:
                raise KeyError(f"{key} not passed to the constructor.")

    @classmethod
    def _set_up_prompts(cls, config):
        """Instantiate the system, query and human prompt templates from `config` via hydra."""
        return {
            key: hydra.utils.instantiate(config[key], _convert_="partial")
            for key in (
                "system_message_prompt_template",
                "query_message_prompt_template",
                "human_message_prompt_template",
            )
        }

    @classmethod
    def _set_up_demonstration_templates(cls, config):
        """Instantiate `demonstrations_response_template` if `config` defines one."""
        kwargs = {}
        if "demonstrations_response_template" in config:
            kwargs["demonstrations_response_template"] = hydra.utils.instantiate(
                config["demonstrations_response_template"], _convert_="partial"
            )
        return kwargs

    @classmethod
    def _set_up_response_annotators(cls, config):
        """Instantiate every entry of `config["response_annotators"]` (if any).

        A new dict is built so that `config` keeps the annotator *configs*
        rather than being overwritten with instantiated objects.
        """
        annotator_configs = config.get("response_annotators", {}) or {}
        response_annotators = {
            key: hydra.utils.instantiate(annotator_config, _convert_="partial")
            for key, annotator_config in annotator_configs.items()
        }
        return {"response_annotators": response_annotators}

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config: prompts, demonstration templates and annotators."""
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~ Set up demonstration templates ~~~
        kwargs.update(cls._set_up_demonstration_templates(flow_config))

        # ~~~ Set up response annotators ~~~
        kwargs.update(cls._set_up_response_annotators(flow_config))

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _is_conversation_initialized(self):
        """Return whether the system message (and demonstrations) were already logged."""
        return self.flow_state["conversation_initialized"]

    def expected_inputs_given_state(self):
        """Return the input keys needed for the next run.

        Only `query` once the conversation is initialized, otherwise the keys
        listed in `flow_config["expected_inputs"]`.
        """
        if self._is_conversation_initialized():
            return ["query"]
        return self.flow_config["expected_inputs"]

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Format `prompt_template` with the values of its input variables taken from `input_data`."""
        template_kwargs = {name: input_data[name] for name in prompt_template.input_variables}
        return prompt_template.format(**template_kwargs)

    def _get_demonstration_query_message_content(self, sample_data: Dict):
        """Return `(query text, parent message ids)` for one demonstration example."""
        input_variables = self.query_message_prompt_template.input_variables
        return self.query_message_prompt_template.format(**{k: sample_data[k] for k in input_variables}), []

    def _get_demonstration_response_message_content(self, sample_data: Dict):
        """Return `(response text, parent message ids)` for one demonstration example."""
        input_variables = self.demonstrations_response_template.input_variables
        return self.demonstrations_response_template.format(**{k: sample_data[k] for k in input_variables}), []

    def _get_annotator_with_key(self, key: str):
        """Return the first response annotator whose `.key` equals `key`, or None."""
        for ra in (self.response_annotators or {}).values():
            if ra.key == key:
                return ra
        return None

    def _response_parsing(self, response: str, expected_outputs: List[str]):
        """Turn the raw model reply into a dict of outputs.

        Every annotator whose `.key` is among `expected_outputs` is applied to
        `response` and the resulting dicts are merged. When no annotator applies,
        the raw reply is returned under the first expected output key (or an
        empty dict if no outputs are expected).
        """
        annotators = self.response_annotators or {}
        target_annotators = [ra for ra in annotators.values() if ra.key in expected_outputs]

        if not target_annotators:
            if not expected_outputs:
                return {}
            return {expected_outputs[0]: response}

        parsed_outputs = {}
        for ra in target_annotators:
            parsed_outputs.update(ra(response))
        return parsed_outputs

    def _add_demonstrations(self):
        """Log each demonstration as a user query followed by an assistant response."""
        if self.demonstrations is None:
            return

        for example in self.demonstrations:
            query, query_parents = self._get_demonstration_query_message_content(example)
            response, response_parents = self._get_demonstration_response_message_content(example)

            # Each message keeps its own parent ids (previously the query reused the response's).
            self._log_chat_message(
                content=query, message_creator=self.user_name, parent_message_ids=query_parents
            )
            self._log_chat_message(
                content=response, message_creator=self.assistant_name, parent_message_ids=response_parents
            )

    def _log_chat_message(self, message_creator: str, content: str, parent_message_ids: Optional[List[str]] = None):
        """Wrap `content` in a ChatMessage attributed to this flow run and log it."""
        chat_message = ChatMessage(
            message_creator=message_creator,
            parent_message_ids=parent_message_ids,
            flow_runner=self.flow_config["name"],
            flow_run_id=self.flow_run_id,
            content=content,
        )
        return self._log_message(chat_message)

    def _initialize_conversation(self, input_data: Dict[str, Any]):
        """Log the system message and the demonstrations, then mark the conversation initialized."""
        # ~~~ Add the system message ~~~
        system_message_content = self._get_message(self.system_message_prompt_template, input_data)

        self._log_chat_message(content=system_message_content, message_creator=self.system_name)

        # ~~~ Add the demonstration query-response tuples (if any) ~~~
        self._add_demonstrations()
        self._update_state(update_data={"conversation_initialized": True})

    def get_conversation_messages(self, message_format: Optional[str] = None):
        """Return the chat messages in the history.

        Args:
            message_format: None for the raw ChatMessage objects, or "open_ai"
                to convert them to langchain System/AI/Human messages.

        Raises:
            ValueError: for an unsupported `message_format`, or a message whose
                creator is not one of the three role names.
        """
        messages = self.flow_state["history"].get_chat_messages()

        if message_format is None:
            return messages

        if message_format != "open_ai":
            raise ValueError(
                f"Currently supported conversation message formats: 'open_ai'. '{message_format}' is not supported"
            )

        processed_messages = []
        for message in messages:
            if message.message_creator == self.system_name:
                processed_messages.append(SystemMessage(content=message.content))
            elif message.message_creator == self.assistant_name:
                processed_messages.append(AIMessage(content=message.content))
            elif message.message_creator == self.user_name:
                processed_messages.append(HumanMessage(content=message.content))
            else:
                raise ValueError(f"Unknown name: {message.message_creator}")
        return processed_messages

    @staticmethod
    def _mask_api_key(api_key: Optional[str]) -> str:
        """Return a redacted form of `api_key` that is safe to write to logs."""
        if not api_key:
            return "<missing>"
        key = str(api_key)
        if len(key) <= 8:
            return "***"
        return f"***{key[-4:]}"

    def _call(self):
        """Send the current conversation to the chat model and return the reply text.

        Makes up to `n_api_retries` attempts (at least one), sleeping
        `wait_time_between_retries` seconds between failed attempts only.

        Raises:
            Exception: the error of the last attempt when every attempt fails.
        """
        api_key = self.flow_state["api_key"]

        backend = langchain.chat_models.ChatOpenAI(
            model_name=self.flow_config["model_name"],
            openai_api_key=api_key,
            **self.flow_config["generation_parameters"],
        )

        messages = self.get_conversation_messages(message_format="open_ai")

        max_attempts = max(1, self.n_api_retries)
        response = None
        for attempt in range(1, max_attempts + 1):
            try:
                response = backend(messages).content
                break
            except Exception as e:
                retrying = attempt < max_attempts
                retry_note = (
                    f"Retrying in {self.wait_time_between_retries} seconds..." if retrying else "Giving up."
                )
                # Never write the raw API key to the logs; the last characters are enough to identify it.
                log.error(
                    f"Error {attempt} in calling backend: {e}. "
                    f"Key used: `{self._mask_api_key(api_key)}`. {retry_note}"
                )
                log.error(
                    f"API call raised Exception with the following arguments: "
                    f"\n{self.flow_state['history'].to_string()}"
                )
                if not retrying:
                    raise
                time.sleep(self.wait_time_between_retries)

        if self.flow_config["verbose"]:
            messages_str = self.flow_state["history"].to_string()
            log.info(
                f"\n{colorama.Fore.MAGENTA}~~~ History [{self.flow_config['name']}] ~~~\n"
                f"{colorama.Style.RESET_ALL}{messages_str}"
            )

        return response

    def _prepare_conversation(self, input_data: Dict[str, Any]):
        """Log the user message for this turn, initializing the conversation on the first run.

        After initialization the message is built from `human_message_prompt_template`
        and `input_data["query"]` (KeyError if absent); on the first run it is built
        from `query_message_prompt_template`.
        """
        if self._is_conversation_initialized():
            user_message_content = self.human_message_prompt_template.format(query=input_data["query"])
        else:
            self._initialize_conversation(input_data)
            user_message_content = self._get_message(self.query_message_prompt_template, input_data)

        self._log_chat_message(message_creator=self.user_name, content=user_message_content)

    @flow_run_cache()
    def run(self, input_data: Dict[str, Any], expected_outputs: List[str]) -> Dict[str, Any]:
        """Run one conversation turn and return the requested outputs.

        Args:
            input_data: values for the prompt templates (`query` after the first turn).
            expected_outputs: keys to parse from the reply and return.

        Returns:
            The values of `expected_outputs`, read from the flow state.
        """
        # ~~~ Chat-specific preparation ~~~
        self._prepare_conversation(input_data)

        # ~~~ Call ~~~
        response = self._call()
        # Use the same assistant name that `get_conversation_messages` maps to AIMessage.
        answer_message = self._log_chat_message(message_creator=self.assistant_name, content=response)

        # ~~~ Response parsing ~~~
        parsed_outputs = self._response_parsing(response=response, expected_outputs=expected_outputs)
        self._update_state(update_data=parsed_outputs)

        if self.flow_config["verbose"]:
            parsed_output_messages_str = pprint.pformat(parsed_outputs, indent=4)
            log.info(
                f"\n{colorama.Fore.MAGENTA}~~~ "
                f"Response [{answer_message.message_creator} -- "
                f"{answer_message.message_id} -- "
                f"{answer_message.flow_run_id}] ~~~"
                f"\n{colorama.Fore.YELLOW}Content: {answer_message}{colorama.Style.RESET_ALL}"
                f"\n{colorama.Fore.YELLOW}Parsed Outputs: {parsed_output_messages_str}{colorama.Style.RESET_ALL}"
            )

        # ~~~ The final answer should be in self.flow_state, thus allow_class_namespace=False ~~~
        return self._get_keys_from_state(keys=expected_outputs, allow_class_namespace=False)