import openai
import time
from typing import Dict, List, Optional, Union
from collections import defaultdict


class OpenAIWrapper:
    """Retrying wrapper around the legacy ``openai.ChatCompletion`` API.

    Maintains a list of API keys (currently a single ``key``), tracks a
    per-key failure count, and skips any key that has failed far more often
    than the healthiest one. ``chat`` adds context-window trimming for plain
    string prompts and retries the underlying call up to ``retry`` times.
    """

    # Marker attribute so callers can detect API-backed (vs local) models.
    is_api: bool = True

    def __init__(self,
                 model: str = 'gpt-3.5-turbo-0613',
                 retry: int = 8,
                 wait: int = 5,
                 verbose: bool = False,
                 system_prompt: Optional[str] = None,
                 temperature: float = 0,
                 key: Optional[str] = None):
        """Configure the wrapper.

        Args:
            model: OpenAI chat model name; substrings ('32k', '16k',
                'gpt-4') are used later to infer the context window.
            retry: number of attempts ``chat`` makes before giving up.
            wait: seconds to sleep between retries.
            verbose: print extra diagnostics on failure.
            system_prompt: optional system message prepended to every call.
            temperature: stored default sampling temperature.
            key: single OpenAI API key; stored as a one-element key list.
        """
        # Lazy import: tiktoken is only needed for token counting, so we do
        # not force it on import of this module.
        import tiktoken
        self.tiktoken = tiktoken
        self.model = model
        self.system_prompt = system_prompt
        self.retry = retry
        self.wait = wait
        # Index of the most recently successful key; rotation starts here.
        self.cur_idx = 0
        # Per-key failure counter; defaultdict(int) starts every key at 0.
        self.fail_cnt = defaultdict(int)
        self.fail_msg = 'Failed to obtain answer via API. '
        self.temperature = temperature
        self.keys = [key]
        self.num_keys = 1
        self.verbose = verbose

    def generate_inner(self,
                       inputs: Union[str, List[str]],
                       max_out_len: int = 1024,
                       chat_mode: bool = False,
                       temperature: float = 0) -> str:
        """Build the message list and try each key once against the API.

        Args:
            inputs: a single user prompt, or a list of message strings. In
                ``chat_mode`` the list is interpreted as an alternating
                conversation ending with a user turn; otherwise every item
                becomes a separate 'user' message.
            max_out_len: ``max_tokens`` passed to the API.
            chat_mode: alternate user/assistant roles over ``inputs``.
            temperature: sampling temperature for this call.

        Returns:
            The stripped content of the first completion choice.

        Raises:
            RuntimeError: if every available key failed.
        """
        input_msgs = []
        if self.system_prompt is not None:
            input_msgs.append(dict(role='system', content=self.system_prompt))
        if isinstance(inputs, str):
            input_msgs.append(dict(role='user', content=inputs))
        elif self.system_prompt is not None and isinstance(inputs, list) and len(inputs) == 0:
            # Only a system prompt was supplied; send it on its own.
            pass
        else:
            assert isinstance(inputs, list) and isinstance(inputs[0], str)
            if chat_mode:
                # Choose the starting role so the LAST message is 'user':
                # odd-length lists start with 'user', even-length with
                # 'assistant'.
                roles = ['user', 'assistant'] if len(inputs) % 2 == 1 else ['assistant', 'user']
                # Over-provision the role list; zip truncates to len(inputs).
                roles = roles * len(inputs)
                for role, msg in zip(roles, inputs):
                    input_msgs.append(dict(role=role, content=msg))
            else:
                for s in inputs:
                    input_msgs.append(dict(role='user', content=s))

        for i in range(self.num_keys):
            idx = (self.cur_idx + i) % self.num_keys
            # Skip a key that has failed 20+ times more than the best key.
            # (Indexing fail_cnt first guarantees values() is non-empty.)
            if self.fail_cnt[idx] >= min(self.fail_cnt.values()) + 20:
                continue
            try:
                openai.api_key = self.keys[idx]
                response = openai.ChatCompletion.create(
                    model=self.model,
                    messages=input_msgs,
                    max_tokens=max_out_len,
                    n=1,
                    stop=None,
                    temperature=temperature)
                result = response.choices[0].message.content.strip()
                # Remember the working key so the next call starts with it.
                self.cur_idx = idx
                return result
            except Exception:  # narrowed from bare except: keep Ctrl-C working
                print(f'OPENAI KEY {self.keys[idx]} FAILED !!!')
                self.fail_cnt[idx] += 1
                if self.verbose:
                    try:
                        # `response` is unbound when create() itself raised.
                        print(response)
                    except Exception:
                        pass
        # The original signalled total failure with `x = 1 / 0` (a deliberate
        # ZeroDivisionError). Raise an explicit error instead; chat()'s
        # except clause catches it either way.
        raise RuntimeError('Failed to obtain answer: all OpenAI keys failed.')

    def chat(self,
             inputs: Union[str, List[str]],
             max_out_len: int = 1024,
             temperature: float = 0) -> str:
        """High-level entry point: trim to the context window and retry.

        Args:
            inputs: a single prompt string or a list of conversation turns.
            max_out_len: requested completion budget; clamped for string
                inputs so prompt + completion fit the model's window.
            temperature: sampling temperature forwarded to the API.

        Returns:
            The model's answer, or ``self.fail_msg`` (possibly with a
            length-exceeded suffix) on failure.
        """
        if isinstance(inputs, str):
            # Infer the context window from the model name.
            context_window = 4096
            if '32k' in self.model:
                context_window = 32768
            elif '16k' in self.model:
                context_window = 16384
            elif 'gpt-4' in self.model:
                context_window = 8192
            # Will hold out 200 tokens as buffer.
            # NOTE(review): get_token_len is not defined in this chunk —
            # presumably provided elsewhere via self.tiktoken; confirm.
            max_out_len = min(max_out_len, context_window - self.get_token_len(inputs) - 200)
            if max_out_len < 0:
                return self.fail_msg + 'Input string longer than context window. Length Exceeded. '
        # BUGFIX: the original asserted `isinstance(inputs, list)`
        # unconditionally, so every str input crashed with AssertionError
        # immediately after the window check above. str is a valid input
        # (generate_inner handles it directly), so accept both.
        assert isinstance(inputs, (str, list))
        for i in range(self.retry):
            try:
                return self.generate_inner(inputs, max_out_len, chat_mode=True, temperature=temperature)
            except Exception:
                if i != self.retry - 1:
                    if self.verbose:
                        print(f'Try #{i} failed, retrying...')
                    time.sleep(self.wait)
        return self.fail_msg