CoI_Agent / LLM.py
jianghuyihei's picture
update api
1099179
raw
history blame
7.9 kB
from openai import AzureOpenAI, OpenAI,AsyncAzureOpenAI,AsyncOpenAI
from abc import abstractmethod
import os
import httpx
import base64
import logging
import asyncio
import numpy as np
from tenacity import (
retry,
stop_after_attempt,
wait_fixed,
)
def get_content_between_a_b(start_tag, end_tag, text):
    """Collect every substring of *text* enclosed by *start_tag* ... *end_tag*.

    Matches are scanned left to right and never overlap: after a match, the
    search for the next *start_tag* resumes just past the closing *end_tag*.
    Scanning stops at the first *start_tag* that has no *end_tag* after it.

    Returns:
        The captured substrings joined by single spaces, with leading and
        trailing whitespace stripped from the result; ``""`` if nothing matched.
    """
    pieces = []
    open_len, close_len = len(start_tag), len(end_tag)
    cursor = text.find(start_tag)
    while cursor >= 0:
        content_start = cursor + open_len
        close_at = text.find(end_tag, content_start)
        if close_at < 0:
            # Unterminated opening tag: stop collecting.
            break
        pieces.append(text[content_start:close_at])
        cursor = text.find(start_tag, close_at + close_len)
    return " ".join(pieces).strip()
def before_retry_fn(retry_state):
    """Log an INFO message before each retry attempt made by tenacity.

    Used as the ``before=`` hook of the ``@retry`` decorators in this module.
    The first attempt is not logged; only attempt 2 onwards.

    Args:
        retry_state: tenacity ``RetryCallState`` for the call. Its
            ``attempt_number`` is read, and the object itself is included in
            the message for context.
    """
    if retry_state.attempt_number > 1:
        # Lazy %-style args; also drops the stray literal "f" the previous
        # f-string emitted in front of the retry state.
        logging.info(
            "Retrying API call. Attempt #%s, %s",
            retry_state.attempt_number,
            retry_state,
        )
def encode_image(image_path):
    """Read the file at *image_path* and return its bytes base64-encoded as a str."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")
def get_openai_url(img_pth):
    """Build a base64 ``data:`` URL for a local image file.

    The MIME subtype is the text after the last ``.`` in *img_pth*,
    lower-cased, with ``jpg`` mapped to the registered ``jpeg`` subtype.

    Args:
        img_pth: Path to the image file.

    Returns:
        A string of the form ``data:image/<ext>;base64,<payload>``.
    """
    # Lower-case so "photo.JPG" yields "image/jpeg" rather than "image/JPG".
    ext = img_pth.rsplit(".", 1)[-1].lower()
    if ext == "jpg":
        ext = "jpeg"
    base64_image = encode_image(img_pth)
    return f"data:image/{ext};base64,{base64_image}"
class base_llm:
    """Minimal interface shared by the LLM wrappers in this module.

    NOTE(review): ``@abstractmethod`` has no effect here because the class
    does not use ``abc.ABCMeta`` / ``abc.ABC``. ``base_llm`` can still be
    instantiated, and its ``response`` simply returns ``None``.
    """

    def __init__(self) -> None:
        """Nothing to initialise here; subclasses set up their own clients."""

    @abstractmethod
    def response(self, messages, **kwargs):
        """Return the model's reply to *messages*; this base version returns None."""

    def get_imgs(self, prompt, save_path="saves/dalle3.jpg"):
        """Image-generation hook; not implemented in the base class (returns None)."""
class openai_llm(base_llm):
    """Azure OpenAI chat + embedding wrapper with sync and async entry points.

    Required environment variables: ``AZURE_OPENAI_ENDPOINT`` and
    ``AZURE_OPENAI_KEY``. Optional: ``AZURE_OPENAI_API_VERSION`` (an empty
    value is treated as unset), ``EMBEDDING_API_ENDPOINT`` /
    ``EMBEDDING_API_KEY`` (a separate Azure resource used for embeddings) and
    ``EMBEDDING_MODEL``.

    Every network method makes up to 10 attempts, 10 s apart, and returns
    ``None`` if all attempts fail. Failures are re-raised inside the method so
    that the tenacity ``@retry`` decorators actually retry them.
    """

    # Shared tenacity policy. ``retry_error_callback`` makes an exhausted
    # retry return None instead of raising RetryError, so callers keep
    # receiving None on failure.
    _RETRY_KWARGS = dict(
        wait=wait_fixed(10),
        stop=stop_after_attempt(10),
        before=before_retry_fn,
        retry_error_callback=lambda retry_state: None,
    )

    def __init__(self, model="gpt4o-0513", deployment="gpt-4o-0806") -> None:
        """Create the sync and async Azure OpenAI clients.

        Args:
            model: Default ``model`` sent with chat requests; can be
                overridden per call via the ``model`` keyword.
            deployment: Azure deployment name the chat clients are bound to.

        Raises:
            ValueError: If ``AZURE_OPENAI_ENDPOINT`` or ``AZURE_OPENAI_KEY``
                is missing or empty.
        """
        super().__init__()
        self.model = model
        for var in ("AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_KEY"):
            if not os.environ.get(var):
                raise ValueError(f"{var} is not set")
        client_kwargs = dict(
            azure_deployment=deployment,
            azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
            api_key=os.environ["AZURE_OPENAI_KEY"],
            api_version=self._api_version(),
        )
        self.client = AzureOpenAI(**client_kwargs)
        self.async_client = AsyncAzureOpenAI(**client_kwargs)
        # Clients for a separate EMBEDDING_API_ENDPOINT are created lazily on
        # first use and then reused, instead of once per request.
        self._embedding_client = None
        self._embedding_async_client = None

    @staticmethod
    def _api_version():
        """Return AZURE_OPENAI_API_VERSION, treating an empty value as None."""
        return os.environ.get("AZURE_OPENAI_API_VERSION") or None

    @staticmethod
    def _embedding_client_kwargs():
        """Constructor kwargs for a client bound to EMBEDDING_API_ENDPOINT."""
        return dict(
            azure_endpoint=os.environ.get("EMBEDDING_API_ENDPOINT", None),
            api_key=os.environ.get("EMBEDDING_API_KEY", None),
            api_version=openai_llm._api_version(),
            azure_deployment="embedding-3-large",
        )

    def _get_embedding_client(self):
        """Return the sync client used for embeddings (cached after first use)."""
        if not os.environ.get("EMBEDDING_API_ENDPOINT"):
            return self.client
        if self._embedding_client is None:
            self._embedding_client = AzureOpenAI(**self._embedding_client_kwargs())
        return self._embedding_client

    def _get_embedding_async_client(self):
        """Return the async client used for embeddings (cached after first use)."""
        if not os.environ.get("EMBEDDING_API_ENDPOINT"):
            return self.async_client
        if self._embedding_async_client is None:
            self._embedding_async_client = AsyncAzureOpenAI(
                **self._embedding_client_kwargs()
            )
        return self._embedding_async_client

    @staticmethod
    def _unwrap_embeddings(data):
        """Map an embeddings response's ``data`` to None, one vector, or a list of vectors.

        NOTE: a response holding exactly one item returns the bare vector,
        not a one-element list, even if a one-element list was requested.
        """
        if not data:
            return None
        if len(data) == 1:
            return data[0].embedding
        return [item.embedding for item in data]

    def _chat_kwargs(self, messages, kwargs):
        """Arguments for ``chat.completions.create`` with this wrapper's defaults."""
        return dict(
            model=kwargs.get("model", self.model),
            messages=messages,
            n=kwargs.get("n", 1),
            temperature=kwargs.get("temperature", 0.7),
            max_tokens=kwargs.get("max_tokens", 4000),
            timeout=kwargs.get("timeout", 180),
        )

    @staticmethod
    def _log_failure(message, exc):
        """Report a failed attempt on stdout and in the log."""
        print(message)
        logging.warning(exc)

    def cal_cosine_similarity(self, vec1, vec2):
        """Return the cosine similarity of two vectors (lists or numpy arrays).

        Intended for 1-D inputs. NOTE(review): a zero-norm vector yields nan
        (numpy emits a divide warning).
        """
        if isinstance(vec1, list):
            vec1 = np.array(vec1)
        if isinstance(vec2, list):
            vec2 = np.array(vec2)
        return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))

    @retry(**_RETRY_KWARGS)
    def response(self, messages, **kwargs):
        """Return the text of the first choice of a chat completion.

        Args:
            messages: Chat messages passed straight to the API.
            **kwargs: Optional ``model`` (default ``self.model``), ``n`` (1),
                ``temperature`` (0.7), ``max_tokens`` (4000) and ``timeout``
                in seconds (180).

        Returns:
            The first choice's message content, or None once every retry
            attempt has failed.
        """
        try:
            completion = self.client.chat.completions.create(
                **self._chat_kwargs(messages, kwargs)
            )
        except Exception as e:
            model = kwargs.get("model", self.model)
            self._log_failure(f"get {model} response failed: {e}", e)
            raise  # hand the error to tenacity so the retry policy applies
        return completion.choices[0].message.content

    @retry(**_RETRY_KWARGS)
    def get_embbeding(self, text):
        """Embed *text* (a string or a list of strings) with EMBEDDING_MODEL.

        Uses a client bound to EMBEDDING_API_ENDPOINT when that variable is
        set, otherwise the chat client.

        Returns:
            None if the response contains no vectors, the single vector if it
            contains one, otherwise a list of vectors; None once every retry
            attempt has failed.
        """
        client = self._get_embedding_client()
        try:
            result = client.embeddings.create(
                model=os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large"),
                input=text,
                timeout=180,
            )
        except Exception as e:
            self._log_failure(f"get embbeding failed: {e}", e)
            raise  # hand the error to tenacity so the retry policy applies
        return self._unwrap_embeddings(result.data)

    @retry(**_RETRY_KWARGS)
    async def get_embbeding_async(self, text):
        """Async counterpart of :meth:`get_embbeding` with the same return contract."""
        client = self._get_embedding_async_client()
        try:
            result = await client.embeddings.create(
                model=os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large"),
                input=text,
                timeout=180,
            )
        except Exception as e:
            self._log_failure(f"get embbeding failed: {e}", e)
            raise  # hand the error to tenacity so the retry policy applies
        return self._unwrap_embeddings(result.data)

    @retry(**_RETRY_KWARGS)
    async def response_async(self, messages, **kwargs):
        """Async counterpart of :meth:`response` with the same kwargs and return contract."""
        try:
            completion = await self.async_client.chat.completions.create(
                **self._chat_kwargs(messages, kwargs)
            )
        except Exception as e:
            model = kwargs.get("model", self.model)
            self._log_failure(f"get {model} response failed: {e}", e)
            raise  # hand the error to tenacity so the retry policy applies
        return completion.choices[0].message.content
if __name__ == "__main__":
    # Smoke test: embed a few questions and print the cosine similarity of
    # each to the first one. Requires the Azure OpenAI environment variables
    # read by openai_llm.

    def cal_cosine_similarity_matric(matric1, matric2):
        """Return the flattened pairwise cosine-similarity matrix as a list.

        Each argument may be a list or numpy array holding one vector (1-D,
        treated as a single row) or a stack of vectors (2-D, one per row).
        Entry ``i * len(matric2) + j`` is the similarity between row ``i`` of
        *matric1* and row ``j`` of *matric2*.
        """
        matric1 = np.asarray(matric1)
        matric2 = np.asarray(matric2)
        if matric1.ndim == 1:
            matric1 = matric1.reshape(1, -1)
        if matric2.ndim == 1:
            matric2 = matric2.reshape(1, -1)
        dot_product = np.dot(matric1, matric2.T)
        norm1 = np.linalg.norm(matric1, axis=1)
        norm2 = np.linalg.norm(matric2, axis=1)
        cos_sim = dot_product / np.outer(norm1, norm2)
        # Return a plain Python list.
        return cos_sim.flatten().tolist()

    texts = ["What is the capital of France?","What is the capital of Spain?", "What is the capital of Italy?", "What is the capital of Germany?"]
    text = "What is the capital of France?"
    llm = openai_llm()
    embbedings = llm.get_embbeding(texts)
    embbeding = llm.get_embbeding(text)
    # get_embbeding returns None when every attempt failed; stop with a clear
    # message instead of an opaque numpy error.
    if embbedings is None or embbeding is None:
        raise SystemExit("Embedding request failed; see the output above.")
    scores = cal_cosine_similarity_matric(embbedings, embbeding)
    print(scores)