"""Thin wrappers around the OpenAI completion and chat-completion APIs.

Defines a ``Message`` dataclass, retry-wrapped call helpers, and a small
model-class hierarchy (``ModelBase`` -> ``GPTChat`` / ``GPTDavinci``).
"""

from typing import List, Union, Optional, Literal
import dataclasses

from tenacity import (
    retry,
    stop_after_attempt,  # type: ignore
    wait_random_exponential,  # type: ignore
)
import openai

# Roles accepted by the OpenAI chat API.
MessageRole = Literal["system", "user", "assistant"]


@dataclasses.dataclass()
class Message:
    # role: who authored the message (system/user/assistant)
    # content: the message text
    role: MessageRole
    content: str


def message_to_str(message: Message) -> str:
    """Render one message as ``role: content``."""
    return f"{message.role}: {message.content}"


def messages_to_str(messages: List[Message]) -> str:
    """Render a conversation as newline-separated ``role: content`` lines."""
    return "\n".join(message_to_str(message) for message in messages)


@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_completion(
    model: str,
    prompt: str,
    max_tokens: int = 1024,
    stop_strs: Optional[List[str]] = None,
    temperature: float = 0.0,
    num_comps: int = 1,
) -> Union[List[str], str]:
    """Call the legacy text-completion endpoint with exponential-backoff retry.

    Returns a single string when ``num_comps == 1``, otherwise a list of
    one string per requested completion.
    """
    response = openai.Completion.create(
        model=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop_strs,
        n=num_comps,
    )
    if num_comps == 1:
        return response.choices[0].text  # type: ignore
    return [choice.text for choice in response.choices]  # type: ignore


@retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6))
def gpt_chat(
    model: str,
    messages: List[Message],
    max_tokens: int = 1024,
    temperature: float = 0.0,
    num_comps: int = 1,
) -> Union[List[str], str]:
    """Call the chat-completion endpoint with exponential-backoff retry.

    ``messages`` are ``Message`` dataclasses, serialized via
    ``dataclasses.asdict`` into the role/content dicts the API expects.
    Returns a single string when ``num_comps == 1``, otherwise a list.
    Errors are logged and re-raised so tenacity can retry the call.
    """
    try:
        response = openai.ChatCompletion.create(
            model=model,
            messages=[dataclasses.asdict(message) for message in messages],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            n=num_comps,
        )
        if num_comps == 1:
            return response.choices[0].message.content  # type: ignore
        return [choice.message.content for choice in response.choices]  # type: ignore
    except Exception as e:
        print(f"An error occurred while calling OpenAI: {e}")
        raise


class ModelBase:
    """Common interface for completion- and chat-style models."""

    def __init__(self, name: str):
        self.name = name
        # Subclasses that speak the chat API flip this to True.
        self.is_chat = False

    def __repr__(self) -> str:
        return f'{self.name}'

    def generate_chat(
        self,
        messages: List[Message],
        max_tokens: int = 1024,
        temperature: float = 0.2,
        num_comps: int = 1,
    ) -> Union[List[str], str]:
        raise NotImplementedError

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        stop_strs: Optional[List[str]] = None,
        temperature: float = 0.0,
        num_comps: int = 1,
    ) -> Union[List[str], str]:
        raise NotImplementedError


class GPTChat(ModelBase):
    """Chat-style OpenAI model (``gpt-4``, ``gpt-3.5-turbo``, ...)."""

    def __init__(self, model_name: str):
        # BUG FIX: previously bypassed ModelBase.__init__ and set
        # ``self.name`` directly; route through it so base-class
        # initialization stays in one place.
        super().__init__(model_name)
        self.is_chat = True

    def generate_chat(
        self,
        messages: List[Message],
        max_tokens: int = 1024,
        temperature: float = 0.2,
        num_comps: int = 1,
    ) -> Union[List[str], str]:
        """Delegate to :func:`gpt_chat` with this model's name."""
        return gpt_chat(self.name, messages, max_tokens, temperature, num_comps)


class GPT4(GPTChat):
    """Convenience wrapper pinned to the ``gpt-4`` model."""

    def __init__(self):
        super().__init__("gpt-4")


class GPT35(GPTChat):
    """Convenience wrapper pinned to the ``gpt-3.5-turbo`` model."""

    def __init__(self):
        super().__init__("gpt-3.5-turbo")


class GPTDavinci(ModelBase):
    """Legacy completion-style model (e.g. ``text-davinci-003``)."""

    def __init__(self, model_name: str):
        # BUG FIX: previously set ``self.name`` directly without calling
        # ModelBase.__init__, so instances never got an ``is_chat``
        # attribute and reading it raised AttributeError.
        super().__init__(model_name)

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        stop_strs: Optional[List[str]] = None,
        temperature: float = 0.0,
        num_comps: int = 1,
    ) -> Union[List[str], str]:
        """Delegate to :func:`gpt_completion` with this model's name."""
        return gpt_completion(
            self.name, prompt, max_tokens, stop_strs, temperature, num_comps
        )