import time
from collections import defaultdict
from typing import List, Optional, Union

import openai

class OpenAIWrapper:
    """Thin wrapper around the legacy (pre-1.0) OpenAI chat-completion API,
    with simple key rotation and retry-on-failure semantics."""

    is_api: bool = True

    def __init__(self,
                 model: str = 'gpt-3.5-turbo-0613',
                 retry: int = 8,
                 wait: int = 5,
                 verbose: bool = False,
                 system_prompt: Optional[str] = None,
                 temperature: float = 0,
                 key: Optional[str] = None):
        import tiktoken
        self.tiktoken = tiktoken
        self.model = model
        self.system_prompt = system_prompt
        self.retry = retry
        self.wait = wait
        self.cur_idx = 0
        # Per-key failure counter; keys that fail far more often than the
        # healthiest key are skipped during rotation.
        self.fail_cnt = defaultdict(lambda: 0)
        self.fail_msg = 'Failed to obtain answer via API. '
        self.temperature = temperature
        self.keys = [key]
        self.num_keys = 1
        self.verbose = verbose
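
    def get_token_len(self, prompt: str) -> int:
        # `chat` below calls self.get_token_len, but the method is missing
        # from this snippet. This is a minimal sketch built on the tiktoken
        # module kept in __init__; the cl100k_base fallback is an assumption
        # for model names that tiktoken does not recognize.
        try:
            encoding = self.tiktoken.encoding_for_model(self.model)
        except KeyError:
            encoding = self.tiktoken.get_encoding('cl100k_base')
        return len(encoding.encode(prompt))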
    def generate_inner(self,
                       inputs: Union[str, List[str]],
                       max_out_len: int = 1024,
                       chat_mode: bool = False,
                       temperature: float = 0) -> str:
        input_msgs = []
        if self.system_prompt is not None:
            input_msgs.append(dict(role='system', content=self.system_prompt))
        if isinstance(inputs, str):
            input_msgs.append(dict(role='user', content=inputs))
        elif isinstance(inputs, list) and len(inputs) == 0:
            # Empty list: send the system prompt alone (if any).
            pass
        else:
            assert isinstance(inputs, list) and isinstance(inputs[0], str)
            if chat_mode:
                # Alternate user/assistant roles so that the final message in
                # the history always comes from the user.
                roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
                roles = roles * len(inputs)
                for role, msg in zip(roles, inputs):
                    input_msgs.append(dict(role=role, content=msg))
            else:
                for s in inputs:
                    input_msgs.append(dict(role='user', content=s))
        # Round-robin over the available keys, skipping any key that has
        # failed at least 20 times more often than the healthiest one.
        for i in range(self.num_keys):
            idx = (self.cur_idx + i) % self.num_keys
            if self.fail_cnt[idx] >= min(self.fail_cnt.values()) + 20:
                continue
            try:
                openai.api_key = self.keys[idx]
                # Legacy openai<1.0 SDK call.
                response = openai.ChatCompletion.create(
                    model=self.model,
                    messages=input_msgs,
                    max_tokens=max_out_len,
                    n=1,
                    stop=None,
                    temperature=temperature)
                result = response.choices[0].message.content.strip()
                self.cur_idx = idx
                return result
            except Exception:
                print(f'OPENAI KEY {self.keys[idx]} FAILED !!!')
                self.fail_cnt[idx] += 1
                if self.verbose:
                    try:
                        print(response)
                    except NameError:
                        # `response` is unbound if the request itself failed.
                        pass
        # Every key failed: raise so that `chat` can retry.
        raise RuntimeError('All API keys failed to return a response.')

    def chat(self, inputs, max_out_len=1024, temperature=0):
        # `generate_inner` accepts either a single string or a list of turns;
        # the original asserted `isinstance(inputs, list)` after the string
        # branch, which made any string input fail the assertion.
        assert isinstance(inputs, (str, list))
        if isinstance(inputs, str):
            # Budget the completion length against the model's context window.
            context_window = 4096
            if '32k' in self.model:
                context_window = 32768
            elif '16k' in self.model:
                context_window = 16384
            elif 'gpt-4' in self.model:
                context_window = 8192
            # Hold out 200 tokens as a safety buffer.
            max_out_len = min(max_out_len, context_window - self.get_token_len(inputs) - 200)
            if max_out_len < 0:
                return self.fail_msg + 'Input string longer than context window. Length Exceeded. '
        for i in range(self.retry):
            try:
                return self.generate_inner(inputs, max_out_len, chat_mode=True, temperature=temperature)
            except Exception:
                if i != self.retry - 1:
                    if self.verbose:
                        print(f'Try #{i} failed, retrying...')
                    time.sleep(self.wait)
        return self.fail_msg
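

# A hedged usage sketch, not part of the original file: 'YOUR_API_KEY' is a
# placeholder, and the reply printed depends on the live API.
if __name__ == '__main__':
    wrapper = OpenAIWrapper(key='YOUR_API_KEY',
                            system_prompt='You are a helpful assistant.',
                            verbose=True)
    # An odd-length history means roles alternate user/assistant/user, so the
    # conversation ends on the user turn, as `generate_inner` expects.
    history = ['What is the capital of France?',
               'The capital of France is Paris.',
               'And of Germany?']
    print(wrapper.chat(history, max_out_len=256))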