File size: 3,578 Bytes
e4f691b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77d904c
10a83a4
e4f691b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10a83a4
e4f691b
 
10a83a4
e4f691b
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import logging
import os
from abc import ABC, abstractmethod
from gradio_client import Client

import openai
import tiktoken
from transformers import T5Tokenizer

# Value of the HUB_TOKEN environment variable; None when it is not set.
# NOTE(review): not referenced in the visible part of this module - confirm it is still needed.
HUB_TOKEN = os.getenv("HUB_TOKEN")

# Importing this module configures logging at INFO level, then creates the
# module-level logger used by the functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class TextCompletion(ABC):
    """Abstract interface shared by every text-completion backend.

    Subclasses must implement ``get_token_count`` and ``complete``; they may
    override ``get_score_multiplier`` to change the default multiplier.
    """

    @abstractmethod
    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens ``prompt`` occupies for this backend.

        Args:
            prompt: The text to measure.

        Returns:
            The token count as an integer.
        """
        ...

    @abstractmethod
    def complete(self, prompt: str):
        """Return the backend's completion for ``prompt``.

        Args:
            prompt: The text sent to the backend.
        """
        ...

    def get_score_multiplier(self) -> float:
        """Return the score multiplier for this backend (1.0 unless overridden)."""
        return 1.0


class DebugCompleter(TextCompletion):
    """Canned-response completer; only meant for testing/debugging purposes."""

    def __init__(self, model):
        # The model name is only used in the simulated error message.
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Return the token count of ``prompt`` under the text-davinci-003 encoding."""
        tokens = tiktoken.encoding_for_model("text-davinci-003").encode(prompt)
        return len(tokens)

    def complete(self, prompt: str):
        """Return a canned completion selected by marker strings in ``prompt``.

        Returns:
            "I have been PWNED" when the prompt contains ``$CORRECT$``,
            otherwise "Some incorrect completion".

        Raises:
            RuntimeError: when the prompt contains ``$ERROR$`` but not
                ``$CORRECT$`` (the success marker is checked first).
        """
        if "$CORRECT$" in prompt:
            # Simulate a successful attack.
            return "I have been PWNED"
        if "$ERROR$" in prompt:
            # Simulate a backend failure.
            raise RuntimeError(f"User raised an error explicitly in {self.model}")
        return "Some incorrect completion"


class GPTCompletor(TextCompletion):
    """Completer backed by ``openai.Completion`` (plain text-completion models)."""

    # Model whose encoding is used when tiktoken has no mapping for ``self.model``.
    _FALLBACK_ENCODING_MODEL = "text-davinci-003"

    def __init__(self, model):
        """Store the model name and read the API key from ``OPENAI_API_KEY``.

        Args:
            model: Name of the OpenAI model passed as ``engine`` to the API.
        """
        self.openai_api_key = os.getenv("OPENAI_API_KEY")
        self.model = model

    def get_token_count(self, prompt: str) -> int:
        """Return the number of tokens in ``prompt`` for the configured model.

        The encoding is looked up for ``self.model`` so the count matches the
        model that is actually queried. If tiktoken does not know the model
        name (it raises ``KeyError``), the text-davinci-003 encoding is used.

        Args:
            prompt: The text to tokenize.

        Returns:
            The number of tokens produced by the encoding.
        """
        try:
            encoding = tiktoken.encoding_for_model(self.model)
        except KeyError:
            encoding = tiktoken.encoding_for_model(self._FALLBACK_ENCODING_MODEL)
        return len(encoding.encode(prompt))

    def complete(self, prompt: str):
        """Send ``prompt`` to the completion API and return the generated text.

        Args:
            prompt: The full prompt text.

        Returns:
            The ``text`` field of the first choice in the API response.
        """
        # The openai module keeps the key as module-level state, so it is
        # (re)assigned before every request.
        openai.api_key = self.openai_api_key

        # Request up to 1024 tokens with temperature 0 and top_p 0.
        response = openai.Completion.create(
            engine=self.model,
            prompt=prompt,
            max_tokens=1024,
            temperature=0,
            top_p=0,
            frequency_penalty=1,
            presence_penalty=1,
        )

        # Only the first choice is used.
        response_text = response["choices"][0]["text"]
        return response_text


class ChatGPTCompletor(TextCompletion):
    """Completer backed by ``openai.ChatCompletion`` (chat models)."""

    def __init__(self, model):
        """Remember the model name and read the API key from ``OPENAI_API_KEY``."""
        self.model = model
        self.openai_api_key = os.getenv("OPENAI_API_KEY")

    def get_token_count(self, prompt: str) -> int:
        """Return how many tokens ``prompt`` encodes to for ``self.model``."""
        return len(tiktoken.encoding_for_model(self.model).encode(prompt))

    def complete(self, prompt: str):
        """Send ``prompt`` as a single user message and return the reply content.

        Returns:
            The ``content`` of the first choice's message in the API response.
        """
        # Assign the key on the openai module before issuing the request.
        openai.api_key = self.openai_api_key

        # The whole prompt is wrapped as one user-role message.
        conversation = [{"role": "user", "content": prompt}]

        reply = openai.ChatCompletion.create(
            messages=conversation,
            model=self.model,
            temperature=0,
        )

        first_choice = reply["choices"][0]
        return first_choice["message"]["content"]

    def get_score_multiplier(self) -> float:
        """Return the score multiplier for this backend, which is 2.0."""
        return 2.0


# Maps each supported model name to the completer class that serves it.
# Classes are instantiated on demand with the model name as the `model` argument.
completers = {
    "text-davinci-003": GPTCompletor,
    "gpt-3.5-turbo": ChatGPTCompletor,
}


def get_completer(model: str):
    """Build the completer instance for ``model``.

    Args:
        model: A key of ``completers``, or ``"gpt-debug"`` for the debugging
            completer.

    Returns:
        A completer instance constructed with ``model=model``.

    Raises:
        NotImplementedError: if ``model`` is neither registered in
            ``completers`` nor ``"gpt-debug"``.
    """
    logger.info(f"Loading completer for {model=}")

    # Check the registry itself rather than a duplicated list of names, so a
    # newly registered model is accepted without touching this function.
    if model in completers:
        return completers[model](model=model)

    if model == "gpt-debug":
        # Not included in completers because we don't want it showing in the app
        logger.warning("Using debugging completer...")
        return DebugCompleter(model=model)

    raise NotImplementedError(f"{model=} not implemented. Model must be one of {list(completers.keys())}")