# Source: models/gpt4_model.py (commit 8145c86, verified, by lindsay-qu)
from .base_model import BaseModel
import openai
from openai import AsyncOpenAI, OpenAI
from tqdm import tqdm
import asyncio
import os
import traceback
class GPT4Model(BaseModel):
    """Thin wrapper around the OpenAI SDK for chat completions and embeddings.

    Reads credentials from the ``OPENAI_API_KEY`` and ``OPENAI_API_BASE``
    environment variables at call time (a ``KeyError`` is raised if either
    is missing). Each generation call is attempted up to two times; on total
    failure the literal fallback string "No answer provided." is returned
    instead of raising.
    """

    # One initial attempt plus one retry, matching the original nested-try logic.
    _MAX_ATTEMPTS = 2
    # The embeddings endpoint accepts at most 2048 inputs per request.
    _EMBED_BATCH_SIZE = 2048

    def __init__(self,
                 generation_model="gpt-4-vision-preview",
                 embedding_model="text-embedding-ada-002",
                 temperature=0,
                 ) -> None:
        """Store model names and sampling temperature.

        Args:
            generation_model: Chat-completion model name.
            embedding_model: Embedding model name.
            temperature: Sampling temperature passed to every generation call.
        """
        self.generation_model = generation_model
        self.embedding_model = embedding_model
        self.temperature = temperature

    async def respond_async(self, messages: list[dict]) -> str:
        """Asynchronously generate a chat completion for *messages*.

        Retries once on any exception; returns "No answer provided." if both
        attempts fail. The last traceback is printed so failures are visible
        (the original silently swallowed it here, unlike ``respond``).
        """
        client = AsyncOpenAI(
            api_key=os.environ["OPENAI_API_KEY"],
            base_url=os.environ["OPENAI_API_BASE"],
        )
        for attempt in range(self._MAX_ATTEMPTS):
            try:
                output = await client.chat.completions.create(
                    messages=messages,
                    model=self.generation_model,
                    temperature=self.temperature,
                    max_tokens=1000,
                )
                return output.choices[0].message.content
            except Exception:
                # Narrowed from a bare ``except:`` so KeyboardInterrupt /
                # SystemExit still propagate.
                if attempt == self._MAX_ATTEMPTS - 1:
                    print(traceback.format_exc())
        return "No answer provided."

    def respond(self, messages: list[dict]) -> str:
        """Synchronously generate a chat completion for *messages*.

        Same retry/fallback contract as :meth:`respond_async`.
        """
        client = OpenAI(
            api_key=os.environ["OPENAI_API_KEY"],
            base_url=os.environ["OPENAI_API_BASE"],
        )
        for attempt in range(self._MAX_ATTEMPTS):
            try:
                output = client.chat.completions.create(
                    messages=messages,
                    model=self.generation_model,
                    temperature=self.temperature,
                    max_tokens=1000,
                )
                return output.choices[0].message.content
            except Exception:
                if attempt == self._MAX_ATTEMPTS - 1:
                    print(traceback.format_exc())
        return "No answer provided."

    def embedding(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts*, batching requests to respect the API input limit.

        Returns one embedding vector per input text, in order. Unlike the
        generation methods, API errors here propagate to the caller.
        """
        client = OpenAI(
            api_key=os.environ["OPENAI_API_KEY"],
            base_url=os.environ["OPENAI_API_BASE"],
        )
        data = []
        for start in range(0, len(texts), self._EMBED_BATCH_SIZE):
            batch = texts[start:start + self._EMBED_BATCH_SIZE]
            data += client.embeddings.create(input=batch,
                                             model=self.embedding_model,
                                             ).data
        return [d.embedding for d in data]