import pprint
import time
from copy import deepcopy
from typing import List, Dict, Optional, Any

import colorama
import hydra
from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, AIMessage, SystemMessage

from flows import utils
from flows.base_flows.abstract import AtomicFlow
from flows.datasets import GenericDemonstrationsDataset
from flows.history import FlowHistory
from flows.message_annotators.abstract import MessageAnnotator
from flows.messages.chat_message import ChatMessage
from flows.utils.caching_utils import flow_run_cache

log = utils.get_pylogger(__name__)


class OpenAIChatAtomicFlow(AtomicFlow):
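    """Wraps an OpenAI chat model, accessed through LangChain's `ChatOpenAI`, as an `AtomicFlow`.

    The flow keeps the conversation in its flow state, builds messages from the
    configured prompt templates (optionally seeded with demonstrations), queries
    the API with retries, and parses the response with the configured response
    annotators.
    """
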
    model_name: str
    generation_parameters: Dict

    system_message_prompt_template: PromptTemplate
    human_message_prompt_template: PromptTemplate

    system_name: str = "system"
    user_name: str = "user"
    assistant_name: str = "assistant"

    n_api_retries: int = 6
    wait_time_between_retries: int = 20

    query_message_prompt_template: Optional[PromptTemplate] = None
    demonstrations: Optional[GenericDemonstrationsDataset] = None
    demonstrations_response_template: Optional[PromptTemplate] = None
    response_annotators: Optional[Dict[str, MessageAnnotator]] = {}

    def __init__(self, **kwargs):
        self._validate_parameters(kwargs)
        super().__init__(**kwargs)

        assert self.flow_config["name"] not in [
            "system",
            "user",
            "assistant",
        ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user' or 'assistant'"

    def set_up_flow_state(self):
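        """Initializes the flow state, marking the conversation as not yet initialized."""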
        super().set_up_flow_state()
        self.flow_state["conversation_initialized"] = False

    @classmethod
    def _validate_parameters(cls, kwargs):
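        """Checks that the flow config and the constructor kwargs contain all required keys."""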
        super()._validate_parameters(kwargs)

        if "model_name" not in kwargs["flow_config"]:
            raise KeyError("model_name not specified in the flow_config.")

        if "generation_parameters" not in kwargs["flow_config"]:
            raise KeyError("generation_parameters not specified in the flow_config.")

        if "system_message_prompt_template" not in kwargs:
            raise KeyError("system_message_prompt_template not passed to the constructor.")

        if "query_message_prompt_template" not in kwargs:
            raise KeyError("query_message_prompt_template not passed to the constructor.")

        if "human_message_prompt_template" not in kwargs:
            raise KeyError("human_message_prompt_template not passed to the constructor.")

    @classmethod
    def _set_up_prompts(cls, config):
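        """Instantiates the system, query, and human message prompt templates from the (hydra) config."""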
        kwargs = {}

        kwargs["system_message_prompt_template"] = \
            hydra.utils.instantiate(config["system_message_prompt_template"], _convert_="partial")
        kwargs["query_message_prompt_template"] = \
            hydra.utils.instantiate(config["query_message_prompt_template"], _convert_="partial")
        kwargs["human_message_prompt_template"] = \
            hydra.utils.instantiate(config["human_message_prompt_template"], _convert_="partial")

        return kwargs

    @classmethod
    def _set_up_demonstration_templates(cls, config):
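        """Instantiates the demonstrations response template from the config, if one is specified."""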
        kwargs = {}

        if "demonstrations_response_template" in config:
            kwargs["demonstrations_response_template"] = \
                hydra.utils.instantiate(config["demonstrations_response_template"], _convert_="partial")

        return kwargs

    @classmethod
    def _set_up_response_annotators(cls, config):
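        """Instantiates the response annotators specified in the config, keyed by name."""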
        response_annotators = config.get("response_annotators", {})
        # Use a distinct loop variable so the `config` argument is not shadowed.
        for key, annotator_config in response_annotators.items():
            response_annotators[key] = hydra.utils.instantiate(annotator_config, _convert_="partial")
        return {"response_annotators": response_annotators}

    @classmethod
    def instantiate_from_config(cls, config):
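        """Builds the flow from a config: deep-copies it and instantiates the prompt templates, demonstration templates, and response annotators."""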
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update(cls._set_up_demonstration_templates(flow_config))
        kwargs.update(cls._set_up_response_annotators(flow_config))

        return cls(**kwargs)

    def _is_conversation_initialized(self):
        return self.flow_state["conversation_initialized"]

    def expected_inputs_given_state(self):
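        """Returns the expected inputs: only "query" once the conversation is initialized, otherwise the inputs declared in the config."""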
        if self._is_conversation_initialized():
            return ["query"]
        else:
            return self.flow_config["expected_inputs"]

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
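        """Formats the prompt template with the corresponding values from the input data."""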
        template_kwargs = {}
        for input_variable in prompt_template.input_variables:
            template_kwargs[input_variable] = input_data[input_variable]

        msg_content = prompt_template.format(**template_kwargs)
        return msg_content

    def _get_demonstration_query_message_content(self, sample_data: Dict):
        input_variables = self.query_message_prompt_template.input_variables
        return self.query_message_prompt_template.format(**{k: sample_data[k] for k in input_variables}), []

    def _get_demonstration_response_message_content(self, sample_data: Dict):
        input_variables = self.demonstrations_response_template.input_variables
        return self.demonstrations_response_template.format(**{k: sample_data[k] for k in input_variables}), []

    def _get_annotator_with_key(self, key: str):
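        """Returns the response annotator with the given key, or None if there is none."""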
        for _, ra in self.response_annotators.items():
            if ra.key == key:
                return ra
        return None

    def _response_parsing(self, response: str, expected_outputs: List[str]):
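        """Parses the raw response with the annotators matching the expected outputs; if none match, the raw response is returned under the first expected output key."""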
        target_annotators = [ra for _, ra in self.response_annotators.items() if ra.key in expected_outputs]

        if len(target_annotators) == 0:
            return {expected_outputs[0]: response}

        parsed_outputs = {}
        for ra in target_annotators:
            parsed_out = ra(response)
            parsed_outputs.update(parsed_out)
        return parsed_outputs

    def _add_demonstrations(self):
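        """Logs each demonstration to the conversation as a (user query, assistant response) message pair."""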
        if self.demonstrations is not None:
            for example in self.demonstrations:
                query, parents = self._get_demonstration_query_message_content(example)
                response, parents = self._get_demonstration_response_message_content(example)

                self._log_chat_message(content=query,
                                       message_creator=self.user_name,
                                       parent_message_ids=parents)

                self._log_chat_message(content=response,
                                       message_creator=self.assistant_name,
                                       parent_message_ids=parents)

    def _log_chat_message(self, message_creator: str, content: str, parent_message_ids: Optional[List[str]] = None):
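        """Creates a ChatMessage and adds it to the flow's history."""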
        chat_message = ChatMessage(
            message_creator=message_creator,
            parent_message_ids=parent_message_ids,
            flow_runner=self.flow_config["name"],
            flow_run_id=self.flow_run_id,
            content=content
        )
        return self._log_message(chat_message)

    def _initialize_conversation(self, input_data: Dict[str, Any]):
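        """Starts the conversation: logs the system message, adds the demonstrations (if any), and marks the conversation as initialized."""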
        system_message_content = self._get_message(self.system_message_prompt_template, input_data)

        self._log_chat_message(content=system_message_content,
                               message_creator=self.system_name)

        self._add_demonstrations()
        self._update_state(update_data={"conversation_initialized": True})

    def get_conversation_messages(self, message_format: Optional[str] = None):
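        """Returns the conversation history, either as ChatMessage objects (default) or, with the "open_ai" format, as the LangChain chat message objects expected by the backend."""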
        messages = self.flow_state["history"].get_chat_messages()

        if message_format is None:
            return messages
        elif message_format == "open_ai":
            processed_messages = []

            for message in messages:
                if message.message_creator == self.system_name:
                    processed_messages.append(SystemMessage(content=message.content))
                elif message.message_creator == self.assistant_name:
                    processed_messages.append(AIMessage(content=message.content))
                elif message.message_creator == self.user_name:
                    processed_messages.append(HumanMessage(content=message.content))
                else:
                    raise ValueError(f"Unknown message creator: {message.message_creator}")
            return processed_messages
        else:
            raise ValueError(
                f"Currently supported conversation message formats: 'open_ai'. '{message_format}' is not supported.")

    def _call(self):
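        """Queries the chat backend with the current conversation, retrying on API errors, and returns the raw response content."""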
        api_key = self.flow_state["api_key"]

        backend = ChatOpenAI(
            model_name=self.flow_config["model_name"],
            openai_api_key=api_key,
            **self.flow_config["generation_parameters"],
        )

        messages = self.get_conversation_messages(
            message_format="open_ai"
        )

        _success = False
        attempts = 1
        error = None
        response = None
        while attempts <= self.n_api_retries:
            try:
                response = backend(messages).content
                _success = True
                break
            except Exception as e:
                log.error(
                    f"Error on attempt {attempts} when calling the backend: {e}. "
                    f"Retrying in {self.wait_time_between_retries} seconds..."
                )
                log.error(
                    f"The API call raised an exception for the following conversation history:"
                    f"\n{self.flow_state['history'].to_string()}"
                )
                attempts += 1
                time.sleep(self.wait_time_between_retries)
                error = e

        if not _success:
            raise error

        if self.flow_config["verbose"]:
            messages_str = self.flow_state["history"].to_string()
            log.info(
                f"\n{colorama.Fore.MAGENTA}~~~ History [{self.flow_config['name']}] ~~~\n"
                f"{colorama.Style.RESET_ALL}{messages_str}"
            )

        return response

    def _prepare_conversation(self, input_data: Dict[str, Any]):
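        """Logs the next user message: on the first call it initializes the conversation and formats the query message template; on later calls it formats the human message template with the new query."""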
        if self._is_conversation_initialized():
            user_message_content = self.human_message_prompt_template.format(query=input_data["query"])
        else:
            self._initialize_conversation(input_data)
            user_message_content = self._get_message(self.query_message_prompt_template, input_data)

        self._log_chat_message(message_creator=self.user_name,
                               content=user_message_content)

    @flow_run_cache()
    def run(self, input_data: Dict[str, Any], expected_outputs: List[str]) -> Dict[str, Any]:
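        """Runs the flow: prepares the conversation, queries the backend, logs and parses the response, and returns the expected outputs from the flow state."""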
        self._prepare_conversation(input_data)

        response = self._call()
        answer_message = self._log_chat_message(
            message_creator=self.assistant_name,
            content=response
        )

        parsed_outputs = self._response_parsing(
            response=response,
            expected_outputs=expected_outputs
        )
        self._update_state(update_data=parsed_outputs)

        if self.flow_config["verbose"]:
            parsed_output_messages_str = pprint.pformat(parsed_outputs, indent=4)
            log.info(
                f"\n{colorama.Fore.MAGENTA}~~~ "
                f"Response [{answer_message.message_creator} -- "
                f"{answer_message.message_id} -- "
                f"{answer_message.flow_run_id}] ~~~"
                f"\n{colorama.Fore.YELLOW}Content: {answer_message.content}{colorama.Style.RESET_ALL}"
                f"\n{colorama.Fore.YELLOW}Parsed Outputs: {parsed_output_messages_str}{colorama.Style.RESET_ALL}"
            )

        return self._get_keys_from_state(keys=expected_outputs, allow_class_namespace=False)