|
|
|
|
|
from __future__ import annotations |
|
|
|
import os |
|
import yaml |
|
import openai |
|
import ast |
|
import pdb |
|
import asyncio |
|
from typing import Any, List |
|
import os |
|
import pathlib |
|
import openai |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpenAIChat():
    """Async wrapper around the (pre-1.0) ``openai.ChatCompletion`` API.

    Sends a batch of chat requests concurrently and parses each reply as a
    Python literal that must be an instance of a caller-supplied type.
    """

    def __init__(
        self,
        model_name='gpt-3.5-turbo',
        max_tokens=2500,
        temperature=0,
        top_p=1,
        request_timeout=60,
    ):
        """Read the API key from the environment and store request settings.

        Args:
            model_name: Model identifier sent with every request. A name that
                does not contain ``'gpt'`` is treated as a locally served
                model, and requests are routed to ``http://localhost:8000/v1``.
            max_tokens: Forwarded as ``max_tokens`` to each request.
            temperature: Forwarded as ``temperature`` to each request.
            top_p: Forwarded as ``top_p`` to each request.
            request_timeout: Forwarded as ``request_timeout`` to each request.

        Raises:
            AssertionError: If ``OPENAI_API_KEY`` is not set (note: ``assert``
                is skipped when Python runs with ``-O``).
        """
        openai.api_key = os.environ.get("OPENAI_API_KEY", None)
        assert openai.api_key is not None, "Please set the OPENAI_API_KEY environment variable."

        # NOTE(review): this mutates module-global ``openai`` state. It affects
        # every other client in the process and is never reset if a later
        # instance uses a 'gpt' model.
        if 'gpt' not in model_name:
            openai.api_base = "http://localhost:8000/v1"

        self.config = {
            'model_name': model_name,
            'max_tokens': max_tokens,
            'temperature': temperature,
            'top_p': top_p,
            'request_timeout': request_timeout,
        }

    def _boolean_fix(self, output):
        """Turn JSON-style ``true``/``false`` into Python ``True``/``False``.

        This is a plain substring replacement, so it also rewrites those
        letters inside string literals (e.g. ``'untrue'`` -> ``'unTrue'``).
        """
        return output.replace("true", "True").replace("false", "False")

    def _type_check(self, output, expected_type):
        """Parse *output* as a Python literal and check its type.

        Args:
            output: Text to parse with ``ast.literal_eval``.
            expected_type: Type (or tuple of types) accepted by ``isinstance``.

        Returns:
            The parsed value, or None if *output* is not a valid literal or
            does not parse to an instance of *expected_type*.
        """
        try:
            output_eval = ast.literal_eval(output)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Model replies are free-form text; anything that is not a valid
            # Python literal counts as a failed prediction.
            return None
        if not isinstance(output_eval, expected_type):
            return None
        return output_eval

    def _parse_prediction(self, prediction, expected_type):
        """Extract the first choice's text from *prediction* and type-check it.

        Returns None when the request failed (*prediction* is None), the reply
        has no text content, or the text does not parse to *expected_type*.
        """
        if prediction is None:
            return None
        content = prediction['choices'][0]['message']['content']
        if content is None:
            # Guard against a reply without a text body instead of crashing
            # in _boolean_fix.
            return None
        return self._type_check(self._boolean_fix(content), expected_type)

    async def dispatch_openai_requests(
        self,
        messages_list,
    ) -> list:
        """Dispatch chat requests to the OpenAI API concurrently.

        Args:
            messages_list: One list of chat messages per request, sent to the
                ChatCompletion API.

        Returns:
            A list aligned with *messages_list*. Each entry is the raw API
            response, or None if every retry attempt hit one of the handled
            transient errors. Other exceptions propagate out of the gather.
        """
        async def _request_with_retry(messages, retry=3):
            for _ in range(retry):
                try:
                    response = await openai.ChatCompletion.acreate(
                        model=self.config['model_name'],
                        messages=messages,
                        max_tokens=self.config['max_tokens'],
                        temperature=self.config['temperature'],
                        top_p=self.config['top_p'],
                        request_timeout=self.config['request_timeout'],
                    )
                    return response
                except openai.error.RateLimitError:
                    print('Rate limit error, waiting for 40 second...')
                    await asyncio.sleep(40)
                except openai.error.APIError:
                    print('API error, waiting for 1 second...')
                    await asyncio.sleep(1)
                except openai.error.Timeout:
                    print('Timeout error, waiting for 1 second...')
                    await asyncio.sleep(1)
                except openai.error.ServiceUnavailableError:
                    print('Service unavailable error, waiting for 3 second...')
                    await asyncio.sleep(3)
                except openai.error.APIConnectionError:
                    print('API Connection error, waiting for 3 second...')
                    await asyncio.sleep(3)
            return None

        async_responses = [
            _request_with_retry(messages)
            for messages in messages_list
        ]
        return await asyncio.gather(*async_responses)

    async def async_run(self, messages_list, expected_type, retry=1):
        """Send every conversation in *messages_list* and parse the replies.

        Requests whose reply cannot be parsed into *expected_type* are sent
        again in the next round, for at most *retry* rounds in total.

        Args:
            messages_list: One list of chat messages per request.
            expected_type: Type (or tuple of types) accepted by ``isinstance``
                that each parsed reply must match.
            retry: Total number of dispatch rounds. The default of 1 (a single
                attempt) matches the previously hard-coded behavior.

        Returns:
            A list aligned with *messages_list*. Each entry is the parsed value
            for that request, or None if no round produced a valid reply.
        """
        responses = [None] * len(messages_list)
        pending = list(range(len(messages_list)))

        while retry > 0 and pending:
            print(f'{retry} retry left...')
            predictions = await self.dispatch_openai_requests(
                messages_list=[messages_list[i] for i in pending],
            )

            still_pending = []
            for index, prediction in zip(pending, predictions):
                pred = self._parse_prediction(prediction, expected_type)
                if pred is None:
                    still_pending.append(index)
                else:
                    responses[index] = pred
            pending = still_pending
            retry -= 1

        return responses
|
|
|
class OpenAIEmbed():
    """Async helper around the (pre-1.0) ``openai.Embedding`` API."""

    DEFAULT_MODEL = "text-embedding-ada-002"

    def __init__(self):
        """Read the API key from the ``OPENAI_API_KEY`` environment variable.

        Raises:
            AssertionError: If ``OPENAI_API_KEY`` is not set (note: ``assert``
                is skipped when Python runs with ``-O``).
        """
        openai.api_key = os.environ.get("OPENAI_API_KEY", None)
        assert openai.api_key is not None, "Please set the OPENAI_API_KEY environment variable."

    async def create_embedding(self, text, retry=3, model=DEFAULT_MODEL):
        """Request an embedding for *text*, retrying on transient API errors.

        Args:
            text: Input passed as ``input`` to ``openai.Embedding.acreate``.
            retry: Maximum number of attempts.
            model: Embedding model name; defaults to ``text-embedding-ada-002``.

        Returns:
            The raw API response, or None if every attempt failed with one of
            the handled transient errors. Other exceptions propagate.
        """
        for _ in range(retry):
            try:
                response = await openai.Embedding.acreate(input=text, model=model)
                return response
            except openai.error.RateLimitError:
                print('Rate limit error, waiting for 1 second...')
                await asyncio.sleep(1)
            except openai.error.APIError:
                print('API error, waiting for 1 second...')
                await asyncio.sleep(1)
            except openai.error.Timeout:
                print('Timeout error, waiting for 1 second...')
                await asyncio.sleep(1)
            # Retry the same transient failures the chat client already retries.
            except openai.error.ServiceUnavailableError:
                print('Service unavailable error, waiting for 3 second...')
                await asyncio.sleep(3)
            except openai.error.APIConnectionError:
                print('API Connection error, waiting for 3 second...')
                await asyncio.sleep(3)
        return None

    async def process_batch(self, batch, retry=3, model=DEFAULT_MODEL):
        """Embed every text in *batch* concurrently.

        Returns:
            A list aligned with *batch*, holding each result of
            ``create_embedding`` (a response, or None on failure).
        """
        tasks = [self.create_embedding(text, retry=retry, model=model) for text in batch]
        return await asyncio.gather(*tasks)
|
|
|
if __name__ == "__main__":
    chat = OpenAIChat()

    # async_run is a coroutine function and must be driven by an event loop;
    # calling it without awaiting only creates a coroutine and sends nothing.
    predictions = asyncio.run(chat.async_run(
        messages_list=[
            [{"role": "user", "content": "show either 'ab' or '['a']'. Do not do anything else."}],
        ] * 20,
        expected_type=list,
    ))
    print(predictions)

    embed = OpenAIEmbed()
    batch = [f"string{i}" for i in range(1, 11)]
    embeddings = asyncio.run(embed.process_batch(batch, retry=3))
    for text, embedding in zip(batch, embeddings):
        if embedding is None:
            # create_embedding returns None once its retries are exhausted.
            print(f"{text}: embedding request failed")
            continue
        print(embedding["data"][0]["embedding"])