Spaces:
Running
Running
import logging | |
import os | |
from abc import ABC, abstractmethod | |
import tiktoken | |
from gradio_client import Client | |
from openai import OpenAI | |
# Module-wide OpenAI client. Passing the key explicitly mirrors the library
# default (it can be omitted); the value is None when OPENAI_API_KEY is unset.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# Token handed to the Flan-T5 Space client as ``hf_token``.
HUB_TOKEN = os.environ.get("HUB_TOKEN")

# Hugging Face Space identifiers.
FLAN_HF_SPACE = "jerpint/i-like-flan"
LLAMA_HF_SPACE = "meta-llama/Llama-2-7b"

# Configure root logging at import time and expose a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TextCompletion(ABC):
    """Abstract interface implemented by every completion backend.

    Concrete subclasses must provide :meth:`get_token_count` and
    :meth:`complete`. :meth:`get_score_multiplier` has a default and may be
    overridden.
    """

    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return the token count of *prompt* as computed by this backend."""
        ...

    @abstractmethod
    def complete(self, prompt: str):
        """Return this backend's completion for *prompt*."""
        ...

    def get_score_multiplier(self) -> float:
        """Return the score multiplier for this backend (``1.0`` by default)."""
        return 1.0
class DebugCompleter(TextCompletion):
    """Stand-in completer intended only for testing and debugging.

    It returns canned responses keyed on marker substrings in the prompt and
    does not call a remote model.
    """

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count the tokens of *prompt* using the ``text-davinci-003`` encoding."""
        token_ids = tiktoken.encoding_for_model("text-davinci-003").encode(prompt)
        return len(token_ids)

    def complete(self, prompt: str):
        """Return a canned completion chosen by markers found in *prompt*.

        Raises:
            RuntimeError: if the prompt contains ``$ERROR$`` (and not
                ``$CORRECT$``), to simulate a backend failure.
        """
        # "$CORRECT$" is checked first, so it wins when both markers appear.
        if "$CORRECT$" in prompt:
            return "I have been PWNED"
        if "$ERROR$" in prompt:
            raise RuntimeError(f"User raised an error explicitly in {self.model}")
        return "Some incorrect completion"
class ChatGPTCompletor(TextCompletion):
    """Completer that sends prompts to the OpenAI chat-completions API."""

    def __init__(self, model):
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Count the tokens of *prompt* using the tiktoken encoding for ``self.model``."""
        token_ids = tiktoken.encoding_for_model(self.model).encode(prompt)
        return len(token_ids)

    def complete(self, prompt: str):
        """Send *prompt* as a single user message and return the reply text.

        The request goes through the module-level ``client`` with
        ``temperature=0``; the content of the first choice is returned.
        """
        response = client.chat.completions.create(
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
            model=self.model,
        )
        first_choice = response.choices[0]
        return first_choice.message.content

    def get_score_multiplier(self) -> float:
        """Return the score multiplier for this backend (``2.0``)."""
        return 2.0
class FlanT5Completor(TextCompletion):
    """Completer that queries the Flan-T5 Hugging Face Space via ``gradio_client``."""

    def __init__(self, model: str):
        self.model = model
        self.client = Client(src=FLAN_HF_SPACE, hf_token=HUB_TOKEN)
        # Loaded lazily on the first token count and then reused.
        self._tokenizer = None

    def _get_tokenizer(self):
        """Return the cached ``google/flan-t5-xxl`` tokenizer, loading it once."""
        # getattr keeps this safe for instances created without __init__.
        tokenizer = getattr(self, "_tokenizer", None)
        if tokenizer is None:
            # Imported here so `transformers` is only required when counting tokens.
            from transformers import T5Tokenizer

            tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
            self._tokenizer = tokenizer
        return tokenizer

    def get_token_count(self, prompt: str) -> int:
        """Return the number of token ids the Flan-T5 tokenizer produces for *prompt*.

        The tokenizer is loaded once per instance instead of on every call,
        and the ids are counted as a plain list so no tensor is built.
        """
        input_ids = self._get_tokenizer()(prompt).input_ids
        return len(input_ids)

    def complete(self, prompt: str) -> str:
        """Return the Space's prediction for *prompt*."""
        return self.client.predict(prompt)
class LLama2Completor(TextCompletion):
    """Completer that forwards prompts to a Gradio-hosted Llama 2 endpoint."""

    def __init__(self, model: str):
        self.model = model
        # NOTE(review): the endpoint is a hard-coded gradio.live URL, which looks
        # like a temporary share link -- confirm it is still reachable.
        # Previously used alternative:
        # self.client = Client(src=LLAMA_HF_SPACE, hf_token=HUB_TOKEN)
        self.client = Client(src="https://e554306b558a76b0fd.gradio.live")

    def get_token_count(self, prompt: str) -> int:
        """Return ``0``; no token count is computed for this backend."""
        return 0

    def complete(self, prompt: str) -> str:
        """Return the endpoint's prediction for *prompt*."""
        return self.client.predict(prompt)
class ClaudeCompletor(TextCompletion):
    """Completer that uses the Anthropic text-completions API."""

    def __init__(self, model="claude-2"):
        # Imported lazily so `anthropic` is only needed when this class is used.
        from anthropic import Anthropic

        self.model = model
        # Raises KeyError when ANTHROPIC_API_KEY is not set.
        self.anthropic = Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    def get_token_count(self, prompt: str):
        """Return a dummy value; not used in the actual competition."""
        return 0

    def complete(self, prompt: str):
        """Wrap *prompt* in the human/assistant markers and return the completion text.

        At most 300 tokens are sampled.
        """
        from anthropic import AI_PROMPT, HUMAN_PROMPT

        wrapped_prompt = f"{HUMAN_PROMPT} {prompt} {AI_PROMPT}"
        response = self.anthropic.completions.create(
            model=self.model,
            max_tokens_to_sample=300,
            prompt=wrapped_prompt,
        )
        return response.completion
# Registry of selectable model names mapped to the completer class that serves
# them. Every entry currently resolves to ChatGPTCompletor; insertion order is
# preserved.
completers = dict.fromkeys(
    (
        "gpt-3.5-turbo",
        "gpt-4o",
        "gpt-4",
        "gpt-3.5-turbo-0125",
        "gpt-4-0613",
        "gpt-4o-2024-05-13",
    ),
    ChatGPTCompletor,
)
def get_completer(model: str) -> TextCompletion:
    """Instantiate and return the completer registered for *model*.

    Args:
        model: A key of ``completers``, or ``"gpt-debug"`` for the debugging
            completer.

    Returns:
        A completer instance constructed with ``model=model``.

    Raises:
        NotImplementedError: if *model* is neither registered in
            ``completers`` nor ``"gpt-debug"``.
    """
    logger.info(f"Loading completer for {model=}")

    if model in completers:
        return completers[model](model=model)

    if model == "gpt-debug":
        # Not included in completers because we don't want it showing in the app.
        logger.warning("Using debugging completer...")
        return DebugCompleter(model=model)

    raise NotImplementedError(
        f"{model=} not implemented. Model must be one of {list(completers)}"
    )