|
from copy import deepcopy

import hydra

import time

from typing import List, Dict, Optional, Any

from langchain import PromptTemplate
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, AIMessage, SystemMessage

from flows.base_flows.abstract import AtomicFlow
from flows.datasets import GenericDemonstrationsDataset

from flows import utils
from flows.messages.flow_message import UpdateMessage_ChatMessage
from flows.utils.caching_utils import flow_run_cache

log = utils.get_pylogger(__name__)
|
class OpenAIChatAtomicFlow(AtomicFlow):
    """An atomic flow wrapping an OpenAI chat model: it keeps the chat history
    in the flow state, renders prompts from templates, and queries the model
    (with retries) on every `run` call."""

    REQUIRED_KEYS_CONFIG = ["model_name", "generation_parameters"]
    REQUIRED_KEYS_KWARGS = ["system_message_prompt_template",
                            "human_message_prompt_template",
                            "query_message_prompt_template"]

    SUPPORTS_CACHING: bool = True

    system_message_prompt_template: PromptTemplate
    human_message_prompt_template: PromptTemplate

    query_message_prompt_template: Optional[PromptTemplate] = None
    demonstrations: Optional[GenericDemonstrationsDataset] = None
    demonstrations_response_template: Optional[PromptTemplate] = None
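    # A sketch of a config satisfying the required keys above (values are
    # illustrative assumptions, not the repo's defaults):
    #
    #   model_name: "gpt-3.5-turbo"
    #   generation_parameters:
    #     temperature: 0.7
    #     max_tokens: 512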
|
|
|
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # The flow's name doubles as a chat role label in logged messages, so it
        # must not collide with the reserved OpenAI chat roles.
        assert self.flow_config["name"] not in [
            "system",
            "user",
            "assistant",
        ], f"Flow name '{self.flow_config['name']}' cannot be 'system', 'user', or 'assistant'"
|
    def set_up_flow_state(self):
        super().set_up_flow_state()
        # The chat history (System/Human/AI messages) lives in the flow state.
        self.flow_state["previous_messages"] = []
|
    @classmethod
    def _set_up_prompts(cls, config):
        # Instantiate the three prompt templates from their Hydra config nodes.
        kwargs = {}

        kwargs["system_message_prompt_template"] = \
            hydra.utils.instantiate(config["system_message_prompt_template"], _convert_="partial")
        kwargs["query_message_prompt_template"] = \
            hydra.utils.instantiate(config["query_message_prompt_template"], _convert_="partial")
        kwargs["human_message_prompt_template"] = \
            hydra.utils.instantiate(config["human_message_prompt_template"], _convert_="partial")

        return kwargs
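    # A sketch of the Hydra node each template is instantiated from (this uses
    # langchain's PromptTemplate fields; the exact schema in the repo's YAML
    # configs may differ):
    #
    #   system_message_prompt_template:
    #     _target_: langchain.PromptTemplate
    #     template: "You are a helpful assistant answering questions about {topic}."
    #     input_variables: ["topic"]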
|
    @classmethod
    def instantiate_from_config(cls, config):
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs["input_data_transformations"] = cls._set_up_data_transformations(config["input_data_transformations"])
        kwargs["output_data_transformations"] = cls._set_up_data_transformations(config["output_data_transformations"])

        kwargs.update(cls._set_up_prompts(flow_config))

        return cls(**kwargs)
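    # Usage sketch (hypothetical): `cfg` is the flow's resolved config, e.g.
    # loaded from a YAML file via hydra/OmegaConf:
    #
    #   flow = OpenAIChatAtomicFlow.instantiate_from_config(cfg)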
|
|
|
    def _is_conversation_initialized(self):
        return len(self.flow_state["previous_messages"]) > 0
|
    def get_input_keys(self, data: Optional[Dict[str, Any]] = None):
        """Returns the expected inputs for the flow given the current state and, optionally, the input data"""
        if self._is_conversation_initialized():
            return self.flow_config["default_human_input_keys"]
        else:
            return self.flow_config["input_keys"]
|
    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        # Render the template using only the variables it declares; extra keys
        # in input_data are ignored.
        template_kwargs = {}
        for input_variable in prompt_template.input_variables:
            template_kwargs[input_variable] = input_data[input_variable]

        msg_content = prompt_template.format(**template_kwargs)
        return msg_content
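    # For example, given PromptTemplate(template="Answer a question about {topic}.",
    # input_variables=["topic"]) and input_data={"topic": "graphs", "id": 7},
    # only `topic` is read and the rendered message is "Answer a question about graphs.".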
|
    def _state_update_add_chat_message(self,
                                       role: str,
                                       content: str) -> None:
        # Append the message to the chat history under the matching message type.
        if role == self.flow_config["system_name"]:
            self.flow_state["previous_messages"].append(SystemMessage(content=content))
        elif role == self.flow_config["user_name"]:
            self.flow_state["previous_messages"].append(HumanMessage(content=content))
        elif role == self.flow_config["assistant_name"]:
            self.flow_state["previous_messages"].append(AIMessage(content=content))
        else:
            raise ValueError(f"Invalid role: `{role}`.\n"
                             f"Role should be one of: "
                             f"`{self.flow_config['system_name']}`, "
                             f"`{self.flow_config['user_name']}`, "
                             f"`{self.flow_config['assistant_name']}`")

        # Log the update so the message appears in the flow's history/trace.
        chat_message = UpdateMessage_ChatMessage(
            created_by=self.flow_config["name"],
            updated_flow=self.flow_config["name"],
            role=role,
            content=content,
        )
        self._log_message(chat_message)
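    # The role names are configurable; a typical config (an assumption here,
    # mirroring OpenAI's chat roles) sets:
    #   system_name: "system"
    #   user_name: "user"
    #   assistant_name: "assistant"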
|
|
|
    def _call(self, api_key: str):
        backend = ChatOpenAI(
            model_name=self.flow_config["model_name"],
            openai_api_key=api_key,
            **self.flow_config["generation_parameters"],
        )

        messages = self.flow_state["previous_messages"]

        _success = False
        attempts = 1
        error = None
        response = None
        while attempts <= self.flow_config["n_api_retries"]:
            try:
                response = backend(messages).content
                _success = True
                break
            except Exception as e:
                # Don't log the raw API key; it would leak secrets into the logs.
                log.error(
                    f"Error {attempts} in calling backend: {e}. "
                    f"Retrying in {self.flow_config['wait_time_between_retries']} seconds..."
                )
                attempts += 1
                time.sleep(self.flow_config["wait_time_between_retries"])
                error = e

        if not _success:
            raise error

        return response
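    # `n_api_retries` and `wait_time_between_retries` come from the flow config,
    # e.g. (illustrative values, not the repo's defaults):
    #   n_api_retries: 6
    #   wait_time_between_retries: 20  # seconds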
|
|
|
    def _initialize_conversation(self, input_data: Dict[str, Any]):
        # The first message in the chat history is the rendered system prompt.
        system_message_content = self._get_message(self.system_message_prompt_template, input_data)

        self._state_update_add_chat_message(content=system_message_content,
                                            role=self.flow_config["system_name"])
|
    def _process_input(self, input_data: Dict[str, Any]):
        if self._is_conversation_initialized():
            # Continuing conversation: render the follow-up (human) message.
            user_message_content = self._get_message(self.human_message_prompt_template, input_data)
        else:
            # First call: set up the system message, then render the initial query.
            self._initialize_conversation(input_data)
            user_message_content = self._get_message(self.query_message_prompt_template, input_data)

        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)
|
    @flow_run_cache()
    def run(self,
            input_data: Dict[str, Any],
            private_keys: Optional[List[str]] = [],
            keys_to_ignore_for_hash: Optional[List[str]] = []) -> str:
        # `private_keys` and `keys_to_ignore_for_hash` are presumably consumed by
        # @flow_run_cache; they are never mutated here, so the mutable defaults
        # are harmless.

        # Prefer an api_key passed in the input data over the one in the config,
        # and pop it so it is never rendered into a prompt.
        api_key = self.flow_config.get("api_key", "")
        if "api_key" in input_data:
            api_key = input_data.pop("api_key")

        self._process_input(input_data)

        response = self._call(api_key)
        self._state_update_add_chat_message(
            role=self.flow_config["assistant_name"],
            content=response,
        )

        return response
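
# End-to-end usage sketch (hypothetical, not part of the library): drive the
# flow with a resolved config and an OpenAI key. All config keys referenced in
# this module (input_keys, system_name, n_api_retries, ...) must be present in
# `cfg`, and the input key shown is an assumption.
#
#   flow = OpenAIChatAtomicFlow.instantiate_from_config(cfg)
#   answer = flow.run(input_data={"topic": "graphs", "api_key": os.environ["OPENAI_API_KEY"]})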
|
|