from typing import List, Union, Optional, Literal
import dataclasses

from tenacity import (
    retry,
    stop_after_attempt,  # type: ignore
    wait_random_exponential,  # type: ignore
)

import openai

MessageRole = Literal["system", "user", "assistant"]

# Must be a dataclass: gpt_chat() below serializes it with dataclasses.asdict().
@dataclasses.dataclass
class Message:
    role: MessageRole
    content: str

def message_to_str(message: Message) -> str:
    return f"{message.role}: {message.content}"

def messages_to_str(messages: List[Message]) -> str:
    return "\n".join([message_to_str(message) for message in messages])

@retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6))
def gpt_completion(
    model: str,
    prompt: str,
    max_tokens: int = 1024,
    stop_strs: Optional[List[str]] = None,
    temperature: float = 0.0,
    num_comps: int = 1,
) -> Union[List[str], str]:
    # Legacy completions endpoint (openai<1.0); retried with exponential
    # backoff on transient failures.
    response = openai.Completion.create(
        model=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop_strs,
        n=num_comps,
    )
    if num_comps == 1:
        return response.choices[0].text  # type: ignore
    return [choice.text for choice in response.choices]  # type: ignore
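
# A hedged usage sketch (hypothetical values; assumes openai.api_key is
# configured and the legacy completions endpoint is available):
#
#   text = gpt_completion("text-davinci-003", "Say hello:", max_tokens=16)
#   print(text)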

@retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6))
def gpt_chat(
    model: str,
    messages: List[Message],
    max_tokens: int = 1024,
    temperature: float = 0.0,
    num_comps: int = 1,
) -> Union[List[str], str]:
    try:
        # Serialize dataclass messages into the dicts the chat endpoint
        # expects (openai<1.0 ChatCompletion API).
        response = openai.ChatCompletion.create(
            model=model,
            messages=[dataclasses.asdict(message) for message in messages],
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            n=num_comps,
        )
        if num_comps == 1:
            return response.choices[0].message.content  # type: ignore
        return [choice.message.content for choice in response.choices]  # type: ignore
    except Exception as e:
        print(f"An error occurred while calling OpenAI: {e}")
        raise
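
# A minimal sketch of a chat call (illustrative prompt; assumes a configured
# API key):
#
#   reply = gpt_chat("gpt-3.5-turbo", [Message(role="user", content="Hi")])
#   print(reply)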

class ModelBase:
    def __init__(self, name: str):
        self.name = name
        self.is_chat = False

    def __repr__(self) -> str:
        return f'{self.name}'

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        raise NotImplementedError

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps: int = 1) -> Union[List[str], str]:
        raise NotImplementedError

class GPTChat(ModelBase):
    def __init__(self, model_name: str):
        super().__init__(model_name)
        self.is_chat = True

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        return gpt_chat(self.name, messages, max_tokens, temperature, num_comps)

class GPT4(GPTChat):
    def __init__(self):
        super().__init__("gpt-4")

class GPT35(GPTChat):
    def __init__(self):
        super().__init__("gpt-3.5-turbo")

class GPTDavinci(ModelBase):
    def __init__(self, model_name: str):
        super().__init__(model_name)

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps: int = 1) -> Union[List[str], str]:
        return gpt_completion(self.name, prompt, max_tokens, stop_strs, temperature, num_comps)
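
# End-to-end sketch (hypothetical: assumes a valid OPENAI_API_KEY in the
# environment and access to the named model):
#
#   model = GPT35()
#   out = model.generate_chat([
#       Message(role="system", content="You are a concise assistant."),
#       Message(role="user", content="Name one prime number."),
#   ])
#   print(out)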