import logging
import os
from abc import ABC, abstractmethod

import tiktoken
from gradio_client import Client
from openai import OpenAI

client = OpenAI(
    # This is the default and can be omitted
    api_key=os.environ.get("OPENAI_API_KEY"),
)

HUB_TOKEN = os.getenv("HUB_TOKEN")
FLAN_HF_SPACE = "jerpint/i-like-flan"
LLAMA_HF_SPACE = "meta-llama/Llama-2-7b"

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextCompletion(ABC):
    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        ...

    @abstractmethod
    def complete(self, prompt: str) -> str:
        ...

    def get_score_multiplier(self) -> float:
        return 1.0
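# New backends plug in by subclassing TextCompletion above and implementing the
# two abstract methods. Hypothetical sketch (EchoCompletor is illustrative only,
# not part of this module):
#
#     class EchoCompletor(TextCompletion):
#         def get_token_count(self, prompt: str) -> int:
#             return len(prompt.split())
#
#         def complete(self, prompt: str) -> str:
#             return prompt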
class DebugCompleter(TextCompletion):
"""This is only meant for testing/debugging purposes."""
def __init__(self, model):
self.model = model
def get_token_count(self, prompt: str) -> int:
encoding = tiktoken.encoding_for_model("text-davinci-003")
return len(encoding.encode(prompt))
def complete(self, prompt: str):
if "$CORRECT$" in prompt:
# returns the correct prompt
return "I have been PWNED"
elif "$ERROR$" in prompt:
# simulate an error
raise RuntimeError(f"User raised an error explicitly in {self.model}")
else:
return "Some incorrect completion"
class ChatGPTCompletor(TextCompletion):
def __init__(self, model):
self.model = model
def get_token_count(self, prompt: str) -> int:
encoding = tiktoken.encoding_for_model(self.model)
return len(encoding.encode(prompt))
def complete(self, prompt: str):
messages = [
{"role": "user", "content": prompt},
]
completion = client.chat.completions.create(
messages=messages,
temperature=0,
model=self.model,
)
# Get the response text
response_text = completion.choices[0].message.content
return response_text
def get_score_multiplier(self) -> float:
return 2.0
class FlanT5Completor(TextCompletion):
    """Completes prompts via the FLAN-T5 Hugging Face Space."""

    def __init__(self, model: str):
        self.model = model
        self.client = Client(src=FLAN_HF_SPACE, hf_token=HUB_TOKEN)

    def get_token_count(self, prompt: str) -> int:
        # Lazy import: transformers is only needed for token counting
        from transformers import T5Tokenizer

        tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        token_count = len(input_ids[0])
        return token_count

    def complete(self, prompt: str) -> str:
        response_text = self.client.predict(prompt)
        return response_text


class LLama2Completor(TextCompletion):
    """Completes prompts via a Llama-2 Gradio endpoint."""

    def __init__(self, model: str):
        self.model = model
        # self.client = Client(src=LLAMA_HF_SPACE, hf_token=HUB_TOKEN)
        self.client = Client(src="https://e554306b558a76b0fd.gradio.live")

    def get_token_count(self, prompt: str) -> int:
        # Not used for this model, so return a dummy value
        return 0

    def complete(self, prompt: str) -> str:
        output = self.client.predict(prompt)
        return output


class ClaudeCompletor(TextCompletion):
def __init__(self, model="claude-2"):
from anthropic import Anthropic
import os
api_key = os.environ["ANTHROPIC_API_KEY"]
self.model = model
self.anthropic = Anthropic(api_key=api_key)
def get_token_count(self, prompt: str):
"Not used in actual competition, so just dummy value"
return 0
def complete(self, prompt: str):
from anthropic import HUMAN_PROMPT, AI_PROMPT
completion = self.anthropic.completions.create(
model=self.model,
max_tokens_to_sample=300,
prompt=f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}",
)
return completion.completion
completers = {
"gpt-3.5-turbo": ChatGPTCompletor,
"gpt-4o": ChatGPTCompletor,
"gpt-4": ChatGPTCompletor,
"gpt-3.5-turbo-0125": ChatGPTCompletor,
"gpt-4-0613": ChatGPTCompletor,
"gpt-4o-2024-05-13": ChatGPTCompletor,
}
def get_completer(model: str):
logger.info(f"Loading completer for {model=}")
if model in completers.keys():
completer = completers[model](model=model)
elif model == "gpt-debug":
# Not included in completers because we dont want it showing in the app
logger.warning("Using debugging completer...")
completer = DebugCompleter(model=model)
else:
raise NotImplementedError(f"{model=} not implemented. Model must be onf of {list(completers.keys())}")
return completer
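# Minimal usage sketch (illustrative, not part of the competition app). The
# "gpt-debug" completer is used here because it needs no API key; real models
# require the matching environment variables (e.g. OPENAI_API_KEY) to be set.
if __name__ == "__main__":
    completer = get_completer("gpt-debug")
    print(completer.get_token_count("I have been PWNED"))
    print(completer.complete("$CORRECT$"))  # prints "I have been PWNED"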