File size: 7,748 Bytes
8360ec7 eca534f 8360ec7 eca534f 8360ec7 eca534f 8360ec7 eca534f 8360ec7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
# the async version is adapted from https://gist.github.com/neubig/80de662fb3e225c18172ec218be4917a
from __future__ import annotations
import os
import yaml
import openai
import ast
import pdb
import asyncio
from typing import Any, List
import os
import pathlib
import openai
from openai import OpenAI, AsyncOpenAI
import re
class OpenAIChat:
    """Batch chat-completion client that parses each reply into a Python literal.

    Requests are sent concurrently through an ``AsyncOpenAI`` client. Each
    reply's text has bare ``true``/``false`` words capitalised, is parsed with
    ``ast.literal_eval``, and is kept only if it is an instance of the
    caller's ``expected_type``.
    """

    # Attempts per request inside ``dispatch_openai_requests`` and the base
    # delay (seconds) of its exponential backoff between attempts
    # (base, 2*base, 4*base, ...).
    request_attempts = 3
    retry_base_delay = 0.5

    # JSON-style booleans as whole words only, so text such as "untrue" or
    # "falsely" is left untouched.
    _JSON_BOOL_RE = re.compile(r"\b(true|false)\b")

    def __init__(
        self,
        model_name='gpt-3.5-turbo',
        max_tokens=2500,
        temperature=0,
        top_p=1,
        request_timeout=120,
        api_base=None,
    ):
        """Store the sampling configuration and build the async API client.

        Args:
            model_name: Model to query. Names that do not contain ``'gpt'`` are
                sent to a local OpenAI-compatible endpoint.
            max_tokens: ``max_tokens`` forwarded to every request.
            temperature: ``temperature`` forwarded to every request.
            top_p: ``top_p`` forwarded to every request.
            request_timeout: Timeout in seconds applied by the API client.
            api_base: Optional endpoint URL. Defaults to
                ``http://localhost:8000/v1`` for non-``gpt`` models; for ``gpt``
                models ``None`` keeps the client's own default endpoint.

        Raises:
            AssertionError: For a ``gpt`` model when ``OPENAI_API_KEY`` is unset
                or empty. The type is kept from the former ``assert`` checks.
        """
        if 'gpt' not in model_name:
            if api_base is None:
                api_base = "http://localhost:8000/v1"
            # Legacy module attribute, still set as before. The v1 client
            # built below does not read it, so base_url is also passed
            # explicitly.
            openai.api_base = api_base
            # The client will not start without some key, even for a local
            # server.
            api_key = os.environ.get("OPENAI_API_KEY") or "EMPTY"
        else:
            openai.api_key = os.environ.get("OPENAI_API_KEY", None)
            if not openai.api_key:
                # Explicit raise (not ``assert``) so the check survives
                # ``python -O``.
                raise AssertionError("Please set the OPENAI_API_KEY environment variable.")
            api_key = openai.api_key
        self.client = AsyncOpenAI(
            api_key=api_key,
            base_url=api_base,
            timeout=request_timeout,
        )
        self.config = {
            'model_name': model_name,
            'max_tokens': max_tokens,
            'temperature': temperature,
            'top_p': top_p,
            'request_timeout': request_timeout,
        }

    @staticmethod
    def _extract_span(text, open_char, close_char):
        """Return ``text`` from the first ``open_char`` to the last ``close_char``, or ``None``."""
        start = text.find(open_char)
        end = text.rfind(close_char)
        # end > start >= 0 also rules out end == -1.
        if 0 <= start < end:
            return text[start:end + 1]
        return None

    def extract_list_from_string(self, input_string):
        """Return the outermost ``[...]`` span of ``input_string``, or ``None`` if there is none."""
        return self._extract_span(input_string, '[', ']')

    def extract_dict_from_string(self, input_string):
        """Return the outermost ``{...}`` span of ``input_string``, or ``None`` if there is none."""
        return self._extract_span(input_string, '{', '}')

    def _boolean_fix(self, output):
        """Replace whole-word ``true``/``false`` with Python's ``True``/``False``."""
        return self._JSON_BOOL_RE.sub(lambda match: match.group(1).capitalize(), output)

    def _type_check(self, output, expected_type):
        """Parse ``output`` as a Python literal and return it only if it has ``expected_type``.

        Returns:
            The parsed value, or ``None`` when ``output`` is not a valid literal
            or the parsed value is not an instance of ``expected_type``.
        """
        try:
            output_eval = ast.literal_eval(output)
            if not isinstance(output_eval, expected_type):
                return None
        # literal_eval raises these on malformed or hostile input. isinstance
        # raises TypeError for types it cannot check (e.g. subscripted
        # generics).
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return None
        return output_eval

    async def dispatch_openai_requests(self, messages_list,) -> list[Any]:
        """Send every conversation in ``messages_list`` to the chat API concurrently.

        Each request is attempted up to ``self.request_attempts`` times. Between
        attempts it waits with exponential backoff.

        Args:
            messages_list: One chat ``messages`` list per request.

        Returns:
            A list aligned with ``messages_list``. Each entry is the completion
            response, or, when every attempt failed, a ``RuntimeError`` whose
            ``__cause__`` is the last underlying error.
        """
        async def _request_with_retry(messages, retry=None):
            attempts = self.request_attempts if retry is None else retry
            last_error = None
            for attempt in range(attempts):
                try:
                    return await self.client.chat.completions.create(
                        model=self.config['model_name'],
                        messages=messages,
                        max_tokens=self.config['max_tokens'],
                        temperature=self.config['temperature'],
                        top_p=self.config['top_p'],
                    )
                except Exception as exc:  # network boundary: any failure is retried
                    last_error = exc
                    # No point sleeping after the final attempt.
                    if attempt + 1 < attempts:
                        await asyncio.sleep((2 ** attempt) * self.retry_base_delay)
            raise RuntimeError("All retries failed for OpenAI API request") from last_error

        # return_exceptions=True: one failed request does not cancel the batch.
        return await asyncio.gather(
            *(_request_with_retry(messages) for messages in messages_list),
            return_exceptions=True,
        )

    def _parse_prediction(self, prediction, expected_type):
        """Return the parsed reply of one completion result, or ``None`` if unusable."""
        # gather(return_exceptions=True) returns failed requests as exception objects.
        if prediction is None or isinstance(prediction, BaseException):
            return None
        choices = getattr(prediction, "choices", None)
        if not choices:
            return None
        content = choices[0].message.content
        if content is None:
            return None
        return self._type_check(self._boolean_fix(content), expected_type)

    def run(self, messages_list, expected_type, retry=1):
        """Query the model for every conversation and parse each reply.

        Args:
            messages_list: One chat ``messages`` list per query.
            expected_type: Type (or tuple of types) that a parsed reply must be
                an instance of, e.g. ``list`` or ``dict``.
            retry: Number of rounds. Each round re-sends only the queries whose
                reply has not yet parsed to ``expected_type``. The default of 1
                matches the previous fixed behaviour.

        Returns:
            A list aligned with ``messages_list``. Each entry is the parsed
            value, or ``None`` where no round produced a valid reply (including
            requests that failed outright).
        """
        responses = [None] * len(messages_list)
        pending = list(range(len(messages_list)))
        # NOTE(review): each round calls asyncio.run, which creates a fresh
        # event loop, while self.client is reused across rounds. Confirm the
        # HTTP client tolerates that when retry > 1 or run() is called
        # repeatedly. asyncio.run also cannot be used inside an already
        # running event loop.
        for _ in range(retry):
            if not pending:
                break
            predictions = asyncio.run(self.dispatch_openai_requests(
                messages_list=[messages_list[i] for i in pending],
            ))
            still_pending = []
            for index, prediction in zip(pending, predictions):
                parsed = self._parse_prediction(prediction, expected_type)
                if parsed is None:
                    still_pending.append(index)
                else:
                    responses[index] = parsed
            pending = still_pending
        return responses
# class OpenAIEmbed():
# def __init__():
# openai.api_key = os.environ.get("OPENAI_API_KEY", None)
# assert openai.api_key is not None, "Please set the OPENAI_API_KEY environment variable."
# assert openai.api_key != '', "Please set the OPENAI_API_KEY environment variable."
# async def create_embedding(self, text, retry=3):
# for _ in range(retry):
# try:
# response = await openai.Embedding.acreate(input=text, model="text-embedding-ada-002")
# return response
# except openai.error.RateLimitError:
# print('Rate limit error, waiting for 1 second...')
# await asyncio.sleep(1)
# except openai.error.APIError:
# print('API error, waiting for 1 second...')
# await asyncio.sleep(1)
# except openai.error.Timeout:
# print('Timeout error, waiting for 1 second...')
# await asyncio.sleep(1)
# return None
# async def process_batch(self, batch, retry=3):
# tasks = [self.create_embedding(text, retry=retry) for text in batch]
# return await asyncio.gather(*tasks)
# if __name__ == "__main__":
# chat = OpenAIChat(model_name='llama-2-7b-chat-hf')
# predictions = asyncio.run(chat.async_run(
# messages_list=[
# [{"role": "user", "content": "show either 'ab' or '['a']'. Do not do anything else."}],
# ] * 20,
# expected_type=List,
# ))
# print(predictions)
# Usage
# embed = OpenAIEmbed()
# batch = ["string1", "string2", "string3", "string4", "string5", "string6", "string7", "string8", "string9", "string10"] # Your batch of strings
# embeddings = asyncio.run(embed.process_batch(batch, retry=3))
# for embedding in embeddings:
# print(embedding["data"][0]["embedding"]) |