File size: 2,978 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
from abc import ABC, abstractmethod
class LM(ABC):
    """Abstract class for language models.

    Subclasses implement ``basic_request`` and ``__call__``. Request parameters
    live in ``self.kwargs``; past calls are expected to be appended to
    ``self.history`` as dicts with ``"prompt"`` and ``"response"`` keys
    (that is what ``inspect_history`` reads).
    """

    def __init__(self, model):
        """Store ``model`` together with the default sampling parameters.

        Args:
            model: Model identifier, stored under ``self.kwargs["model"]``.
        """
        self.kwargs = {
            "model": model,
            "temperature": 0.0,
            "max_tokens": 150,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "n": 1,
        }
        # Selects how inspect_history renders stored responses.
        self.provider = "default"
        self.history = []

    @abstractmethod
    def basic_request(self, prompt, **kwargs):
        """Send ``prompt`` to the backend; implemented by subclasses."""
        pass

    def request(self, prompt, **kwargs):
        """Forward to ``basic_request`` and return its result unchanged."""
        return self.basic_request(prompt, **kwargs)

    def print_green(self, text: str, end: str = "\n"):
        """Print ``text`` wrapped in ANSI green escape codes."""
        print("\x1b[32m" + str(text) + "\x1b[0m", end=end)

    def print_red(self, text: str, end: str = "\n"):
        """Print ``text`` wrapped in ANSI red escape codes."""
        print("\x1b[31m" + str(text) + "\x1b[0m", end=end)

    def inspect_history(self, n: int = 1, skip: int = 0):
        """Print the last ``n`` distinct prompts and their completions.

        Only the most recent 100 history entries are examined. Consecutive
        entries with the same prompt are collapsed into one. The ``skip`` most
        recent prompts are omitted, and the ``n`` prompts before them are
        printed oldest first.

        Args:
            n: Number of prompts to print.
            skip: Number of most-recent prompts to leave out.

        TODO: print the valid choice that contains filled output field instead of the first
        """
        provider: str = self.provider
        total = n + skip

        # Collect distinct prompts newest-first until we have enough.
        collected = []
        last_prompt = None
        for entry in reversed(self.history[-100:]):
            prompt = entry["prompt"]
            if prompt != last_prompt:
                collected.append((prompt, entry["response"]))
            last_prompt = prompt
            if len(collected) >= total:
                break

        # Slice on what was actually collected: when history holds fewer than
        # n + skip prompts, counting from `n` would skip the oldest entries
        # instead of the most recent ones.
        selected = collected[skip:total]

        for prompt, choices in reversed(selected):
            print("\n\n\n")
            print(prompt, end="")
            if provider == "cohere":
                text = choices[0].text
            elif provider in ("openai", "ollama"):
                # NOTE(review): _get_choice_text is not defined on LM; presumably
                # provided by the openai/ollama subclasses — confirm.
                text = " " + self._get_choice_text(choices[0]).strip()
            else:
                # clarifai and any other provider: show the stored response as-is.
                text = choices
            self.print_green(text, end="")
            # A plain string response is a single completion; len() would count
            # its characters and report bogus "other completions".
            if not isinstance(choices, str) and len(choices) > 1:
                self.print_red(f" \t (and {len(choices)-1} other completions)", end="")
            print("\n\n\n")

    @abstractmethod
    def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
        """Run the model on ``prompt``; implemented by subclasses."""
        pass

    def copy(self, **kwargs):
        """Returns a copy of the language model with the same parameters.

        ``kwargs`` override entries of ``self.kwargs``. The merged parameters
        (minus ``model``, which is passed positionally) are forwarded as
        keyword arguments to the subclass constructor, so that constructor
        must accept them.
        """
        kwargs = {**self.kwargs, **kwargs}
        model = kwargs.pop('model')
        return self.__class__(model, **kwargs)
|