CodeLATS / generators /model.py
Ron
initial commit
41d1bc5
from typing import List, Union, Optional, Literal
import dataclasses
from tenacity import (
retry,
stop_after_attempt, # type: ignore
wait_random_exponential, # type: ignore
)
import openai
# Type-level alias for the speaker roles a Message is annotated with.
# Literal is only a static-typing hint; nothing here validates it at runtime.
MessageRole = Literal["system", "user", "assistant"]


@dataclasses.dataclass
class Message:
    """One turn of a chat conversation.

    Attributes are plain fields, so ``dataclasses.asdict`` turns an instance
    into a ``{"role": ..., "content": ...}`` dict.
    """

    role: MessageRole  # who is speaking
    content: str  # the text of the turn
def message_to_str(message: Message) -> str:
    """Render a message as ``"<role>: <content>"``."""
    role, content = message.role, message.content
    return "{}: {}".format(role, content)
def messages_to_str(messages: List[Message]) -> str:
    """Render each message via ``message_to_str`` and join them with newlines.

    An empty list yields the empty string.
    """
    return "\n".join(map(message_to_str, messages))
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gpt_completion(
    model: str,
    prompt: str,
    max_tokens: int = 1024,
    stop_strs: Optional[List[str]] = None,
    temperature: float = 0.0,
    num_comps=1,
) -> Union[List[str], str]:
    """Request ``num_comps`` completions of ``prompt`` from ``model``.

    Calls ``openai.Completion.create`` with ``top_p=1`` and zero frequency /
    presence penalties. The decorator retries the whole call: at most 6
    attempts, randomized exponential back-off capped at 60 s.

    Returns the text of the only choice when ``num_comps == 1``; otherwise a
    list with the ``text`` of every returned choice.
    """
    request = dict(
        model=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=stop_strs,
        n=num_comps,
    )
    choices = openai.Completion.create(**request).choices  # type: ignore
    if num_comps != 1:
        return [choice.text for choice in choices]  # type: ignore
    return choices[0].text  # type: ignore
@retry(wait=wait_random_exponential(min=1, max=180), stop=stop_after_attempt(6))
def gpt_chat(
    model: str,
    messages: List[Union["Message", dict]],
    max_tokens: int = 1024,
    temperature: float = 0.0,
    num_comps: int = 1,
) -> Union[List[str], str]:
    """Call ``openai.ChatCompletion.create`` and return the generated content.

    The decorator retries the whole call: at most 6 attempts, randomized
    exponential back-off capped at 180 s between attempts.

    Args:
        model: Chat model name, e.g. ``"gpt-4"``.
        messages: The conversation. Each item is either a ``Message``
            dataclass (converted with ``dataclasses.asdict``) or an already
            built dict, which is passed through unchanged. Any other item type
            makes ``dataclasses.asdict`` raise ``TypeError``.
        max_tokens: Maximum tokens to generate per completion.
        temperature: Sampling temperature.
        num_comps: Number of completions to request (sent as ``n``).

    Returns:
        ``choices[0].message.content`` when ``num_comps == 1``; otherwise a
        list of ``message.content`` for every returned choice.

    Raises:
        Each failed attempt is printed to stdout and re-raised so ``@retry``
        can try again. Once attempts are exhausted the decorator raises
        (presumably tenacity's ``RetryError``, since ``reraise`` is not set —
        confirm against the tenacity version in use).
    """
    try:
        # Dicts go straight through. Calling asdict() on them would raise
        # TypeError, and because this function is wrapped in @retry that
        # caller mistake would be retried with long back-off before surfacing.
        payload = [
            message if isinstance(message, dict) else dataclasses.asdict(message)
            for message in messages
        ]
        response = openai.ChatCompletion.create(
            model=model,
            messages=payload,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=1,
            frequency_penalty=0.0,
            presence_penalty=0.0,
            n=num_comps,
        )
        if num_comps == 1:
            return response.choices[0].message.content  # type: ignore
        return [choice.message.content for choice in response.choices]  # type: ignore
    except Exception as e:
        # Report every failed attempt, then re-raise so @retry decides
        # whether to try again.
        print(f"An error occurred while calling OpenAI: {e}")
        raise
class ModelBase:
    """Common interface for the language-model wrappers.

    The base implementations of ``generate_chat`` and ``generate`` raise
    ``NotImplementedError``; subclasses override the one(s) they support.
    """

    def __init__(self, name: str):
        # Not a chat model unless a subclass says otherwise.
        self.is_chat = False
        self.name = name

    def __repr__(self) -> str:
        # The representation is just the model name.
        return "{}".format(self.name)

    def generate_chat(
        self,
        messages: List[Message],
        max_tokens: int = 1024,
        temperature: float = 0.2,
        num_comps: int = 1,
    ) -> Union[List[str], str]:
        """Generate a reply to a chat conversation (override in subclasses)."""
        raise NotImplementedError

    def generate(
        self,
        prompt: str,
        max_tokens: int = 1024,
        stop_strs: Optional[List[str]] = None,
        temperature: float = 0.0,
        num_comps=1,
    ) -> Union[List[str], str]:
        """Generate a completion for a plain prompt (override in subclasses)."""
        raise NotImplementedError
class GPTChat(ModelBase):
    """Chat model wrapper that sends conversations through ``gpt_chat``."""

    def __init__(self, model_name: str):
        # Let ModelBase perform its own initialisation (it sets ``name``)
        # instead of duplicating it here, then mark this wrapper as chat-capable.
        super().__init__(model_name)
        self.is_chat = True

    def generate_chat(self, messages: List[Message], max_tokens: int = 1024, temperature: float = 0.2, num_comps: int = 1) -> Union[List[str], str]:
        """Send ``messages`` to this wrapper's model via ``gpt_chat``.

        Note the default ``temperature`` here is 0.2, whereas ``gpt_chat``
        itself defaults to 0.0.

        Returns a single string when ``num_comps == 1``; otherwise a list.
        """
        # Keyword arguments keep the mapping explicit even if gpt_chat's
        # optional parameters are ever reordered.
        return gpt_chat(
            self.name,
            messages,
            max_tokens=max_tokens,
            temperature=temperature,
            num_comps=num_comps,
        )
class GPT4(GPTChat):
    """``GPTChat`` preconfigured for the ``"gpt-4"`` model."""

    def __init__(self):
        model_name = "gpt-4"
        super().__init__(model_name)
class GPT35(GPTChat):
    """``GPTChat`` preconfigured for the ``"gpt-3.5-turbo"`` model."""

    def __init__(self):
        model_name = "gpt-3.5-turbo"
        super().__init__(model_name)
class GPTDavinci(ModelBase):
    """Plain-prompt (non-chat) model wrapper that calls ``gpt_completion``."""

    def __init__(self, model_name: str):
        # Delegate to ModelBase so ``is_chat`` is defined (False) alongside
        # ``name``; without this, code branching on ``model.is_chat`` raises
        # AttributeError for this class only.
        super().__init__(model_name)

    def generate(self, prompt: str, max_tokens: int = 1024, stop_strs: Optional[List[str]] = None, temperature: float = 0.0, num_comps=1) -> Union[List[str], str]:
        """Complete ``prompt`` with this wrapper's model via ``gpt_completion``.

        ``stop_strs`` is forwarded as the request's stop sequences. Returns a
        single string when ``num_comps == 1``; otherwise a list.
        """
        return gpt_completion(
            self.name,
            prompt,
            max_tokens=max_tokens,
            stop_strs=stop_strs,
            temperature=temperature,
            num_comps=num_comps,
        )