| import linecache |
| import re |
| from typing import Dict, List, Optional |
|
|
| import openai |
|
|
|
|
class ChatCompletion:
    """Thin stateful wrapper around ``openai.ChatCompletion.create``.

    The instance accumulates system prompts ("settings") and user messages and
    sends all of them, system messages first, on every request. When the
    accumulated text exceeds ``max_context_length`` characters, the whole
    history (including settings) is cleared before the next message is added.

    NOTE: assistant replies are returned to the caller but never stored, so
    the model only sees prior *user* turns, not its own earlier answers.
    """

    # Character (not token) budget checked by ``chat`` before adding a message.
    max_context_length: int = 2048

    def __init__(self, model: str = 'gpt-3.5-turbo',
                 api_key: Optional[str] = None, api_key_path: str = './openai_api_key'):
        """Configure the API key and start with an empty history.

        Args:
            model: default model name used when ``chat``/``retry`` get none.
            api_key: OpenAI API key. When ``None`` the key is read from
                line 2 of ``api_key_path``.
            api_key_path: file holding the key; only read if ``api_key`` is None.

        Raises:
            EnvironmentError: ``api_key`` is None and no key could be read
                from ``api_key_path``.
        """
        if api_key is None:
            # NOTE(review): the key is expected on line 2 of the file
            # (presumably line 1 is a label/comment) -- confirm the format.
            api_key = self._read_api_key(api_key_path)
            if not api_key:
                raise EnvironmentError(
                    f'no OpenAI API key given and none found on line 2 of {api_key_path!r}')
        # openai.api_key is module-global, so this affects every instance.
        openai.api_key = api_key

        self.model = model
        self.system_messages: List[str] = []
        self.user_messages: List[str] = []

    @staticmethod
    def _read_api_key(path: str, lineno: int = 2) -> str:
        """Return line ``lineno`` (1-based) of ``path`` stripped of whitespace.

        Returns '' when the file cannot be read or has fewer than ``lineno``
        lines. Plain ``open`` is used instead of ``linecache``, which caches
        contents and searches ``sys.path`` for relative names.
        """
        try:
            with open(path, encoding='utf-8') as fh:
                lines = fh.readlines()
        except (OSError, UnicodeDecodeError):
            return ''
        if 1 <= lineno <= len(lines):
            # strip() also removes '\r' and stray spaces that would corrupt the key.
            return lines[lineno - 1].strip()
        return ''

    def chat(self, msg: str, setting: Optional[str] = None, model: Optional[str] = None) -> str:
        """Add ``msg`` (and optionally a system ``setting``) to the history and query the model.

        Args:
            msg: user message. It is not re-appended if it equals the last
                user message, so repeated calls do not duplicate it.
            setting: optional system prompt, added once (deduplicated).
            model: model override for this request; defaults to ``self.model``.

        Returns:
            The reply text, or an error message string if the request failed.
        """
        # The budget is checked against the history *before* the new message.
        if self._context_length() > self.max_context_length:
            self.reset()
        if setting is not None and setting not in self.system_messages:
            self.system_messages.append(setting)
        if not self.user_messages or msg != self.user_messages[-1]:
            self.user_messages.append(msg)
        return self._run(model)

    def retry(self, model: Optional[str] = None) -> str:
        """Resend the current history unchanged and return the new reply."""
        return self._run(model)

    def reset(self):
        """Clear both system and user message history in place."""
        self.system_messages.clear()
        self.user_messages.clear()

    def _make_message(self) -> List[Dict]:
        """Build the ``messages`` payload: all system messages, then all user messages."""
        sys_messages = [{'role': 'system', 'content': msg} for msg in self.system_messages]
        user_messages = [{'role': 'user', 'content': msg} for msg in self.user_messages]
        return sys_messages + user_messages

    def _context_length(self) -> int:
        """Return the total number of characters (not tokens) in the history."""
        return sum(map(len, self.system_messages)) + sum(map(len, self.user_messages))

    def _run(self, model: Optional[str] = None) -> str:
        """Send the current history to the API and return the reply text.

        Leading newlines are removed from the reply. On any failure the error
        message is returned as a string (never raised), so the return value is
        always a ``str``. Non-API errors are also printed.
        """
        if model is None:
            model = self.model
        try:
            response = openai.ChatCompletion.create(model=model, messages=self._make_message())
            answer = response['choices'][0]['message']['content']
            # Inside the try so a malformed/None content is reported, not raised.
            return answer.lstrip('\n')
        except openai.error.OpenAIError as e:
            return str(e)
        except Exception as e:
            print(e)
            return str(e)

    def __call__(self, msg: str, setting: Optional[str] = None, model: Optional[str] = None) -> str:
        """Alias for :meth:`chat`."""
        return self.chat(msg, setting, model)
|
|