PromptQuest1 / hackaprompt /completers.py
NexAddo's picture
Duplicate from SynthmindsAI/PromptQuest
38a18ae
raw
history blame
3.58 kB
import logging
import os
from abc import ABC, abstractmethod
from gradio_client import Client
import openai
import tiktoken
from transformers import T5Tokenizer
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Value of the HUB_TOKEN environment variable; None when it is not set.
HUB_TOKEN = os.environ.get("HUB_TOKEN")
class TextCompletion(ABC):
    """Abstract interface implemented by every text-completion backend.

    Subclasses must provide ``get_token_count`` and ``complete``;
    ``get_score_multiplier`` has a default that they may override.
    """

    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens in *prompt*."""
        ...

    @abstractmethod
    def complete(self, prompt: str):
        """Return the completion generated for *prompt*."""
        ...

    def get_score_multiplier(self) -> float:
        """Return the score multiplier of this completer (1.0 by default)."""
        return 1.0
class DebugCompleter(TextCompletion):
    """Stand-in completer intended solely for testing and debugging.

    The reply is selected from marker strings found in the prompt.
    """

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count the tokens of *prompt* using the ``text-davinci-003`` encoding."""
        tokens = tiktoken.encoding_for_model("text-davinci-003").encode(prompt)
        return len(tokens)

    def complete(self, prompt: str):
        """Return a canned reply chosen by the markers present in *prompt*.

        ``$CORRECT$`` takes precedence over ``$ERROR$`` when both are present.

        Raises:
            RuntimeError: If the prompt contains ``$ERROR$`` but not
                ``$CORRECT$``.
        """
        # Marker for the "successful" completion.
        if "$CORRECT$" in prompt:
            return "I have been PWNED"

        # Marker used to simulate a failing backend.
        if "$ERROR$" in prompt:
            raise RuntimeError(f"User raised an error explicitly in {self.model}")

        return "Some incorrect completion"
class GPTCompletor(TextCompletion):
    """Completer that sends prompts to ``openai.Completion``."""

    def __init__(self, model):
        self.model = model
        # Read once at construction time; None when the variable is unset.
        self.openai_api_key = os.getenv("OPENAI_API_KEY")

    def get_token_count(self, prompt: str) -> int:
        """Count the tokens of *prompt* using the ``text-davinci-003`` encoding."""
        tokens = tiktoken.encoding_for_model("text-davinci-003").encode(prompt)
        return len(tokens)

    def complete(self, prompt: str):
        """Request a completion for *prompt* and return the first choice's text."""
        # The key is assigned on the openai module itself before each call.
        openai.api_key = self.openai_api_key

        request = {
            "engine": self.model,
            "prompt": prompt,
            "max_tokens": 1024,
            "temperature": 0,
            "top_p": 0,
            "frequency_penalty": 1,
            "presence_penalty": 1,
        }
        result = openai.Completion.create(**request)

        first_choice = result["choices"][0]
        return first_choice["text"]
class ChatGPTCompletor(TextCompletion):
    """Completer that sends prompts to ``openai.ChatCompletion``."""

    def __init__(self, model):
        self.model = model
        # Read once at construction time; None when the variable is unset.
        self.openai_api_key = os.getenv("OPENAI_API_KEY")

    def get_token_count(self, prompt: str) -> int:
        """Count the tokens of *prompt* using the encoding of ``self.model``."""
        tokens = tiktoken.encoding_for_model(self.model).encode(prompt)
        return len(tokens)

    def complete(self, prompt: str):
        """Send *prompt* as a single user message and return the reply content."""
        # The key is assigned on the openai module itself before each call.
        openai.api_key = self.openai_api_key

        user_message = {"role": "user", "content": prompt}
        result = openai.ChatCompletion.create(
            messages=[user_message],
            model=self.model,
            temperature=0,
        )

        first_choice = result["choices"][0]
        return first_choice["message"]["content"]

    def get_score_multiplier(self) -> float:
        """Return 2.0, overriding the base-class default of 1.0."""
        return 2.0
# Registry mapping each supported model name to the completer class that
# handles it. get_completer() instantiates the class with ``model=<name>`` and
# lists these keys in its error message for unsupported models.
completers = {
    "text-davinci-003": GPTCompletor,
    "gpt-3.5-turbo": ChatGPTCompletor,
}
def get_completer(model: str):
    """Build the completer instance that handles *model*.

    Args:
        model: Name of the model; either a key of ``completers`` or the
            hidden debugging name ``"gpt-debug"``.

    Returns:
        An instance of the matching completer class, constructed with
        ``model=model``.

    Raises:
        NotImplementedError: If *model* is neither registered in
            ``completers`` nor ``"gpt-debug"``.
    """
    logger.info(f"Loading completer for {model=}")

    # Consult the registry itself instead of a second hard-coded list of names,
    # so that registering a new completer is enough to make it loadable.
    if model in completers:
        return completers[model](model=model)

    if model == "gpt-debug":
        # Not included in completers because we don't want it showing in the app
        logger.warning("Using debugging completer...")
        return DebugCompleter(model=model)

    raise NotImplementedError(f"{model=} not implemented. Model must be one of {list(completers.keys())}")