Spaces:
Running
Running
from openai import AzureOpenAI, OpenAI,AsyncAzureOpenAI,AsyncOpenAI | |
from abc import abstractmethod | |
import os | |
import httpx | |
import base64 | |
import logging | |
import asyncio | |
import numpy as np | |
from tenacity import ( | |
retry, | |
stop_after_attempt, | |
wait_fixed, | |
) | |
def get_content_between_a_b(start_tag, end_tag, text):
    """Collect every piece of ``text`` enclosed by ``start_tag`` ... ``end_tag``.

    The text is scanned left to right. Each time ``start_tag`` is found, the
    characters up to the next ``end_tag`` are captured and scanning resumes
    after that ``end_tag``. A ``start_tag`` with no following ``end_tag``
    ends the scan and contributes nothing.

    Args:
        start_tag: Marker that opens a segment.
        end_tag: Marker that closes a segment.
        text: String to search.

    Returns:
        The captured segments joined by single spaces, with leading and
        trailing whitespace removed. Returns ``""`` when nothing matches.
    """
    # Two empty tags match at every position without consuming anything, so
    # the scan below would never advance; there is nothing to delimit.
    if not start_tag and not end_tag:
        return ""

    segments = []
    start_len = len(start_tag)
    end_len = len(end_tag)

    start_index = text.find(start_tag)
    while start_index != -1:
        content_start = start_index + start_len
        end_index = text.find(end_tag, content_start)
        if end_index == -1:
            # Unterminated segment: stop and keep what was collected so far.
            break
        segments.append(text[content_start:end_index])
        start_index = text.find(start_tag, end_index + end_len)

    return " ".join(segments).strip()
def before_retry_fn(retry_state):
    """Log a message when an API call is about to be retried.

    Nothing is logged for the first attempt; only attempts after the first
    produce an INFO record on the root logger.

    Args:
        retry_state: Object exposing an ``attempt_number`` attribute.
            Presumably a tenacity retry state passed to a ``before`` hook;
            the hook is not wired up in the visible code, so confirm at the
            call site.
    """
    if retry_state.attempt_number > 1:
        # The original message contained a stray literal "f" in front of the
        # state; arguments are passed lazily so formatting only happens when
        # the record is actually emitted.
        logging.info(
            "Retrying API call. Attempt #%s, %s",
            retry_state.attempt_number,
            retry_state,
        )
def encode_image(image_path):
    """Return the contents of the file at ``image_path`` as Base64 text.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        The Base64 encoding of the file's bytes, decoded to a ``str``.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
def get_openai_url(img_pth):
    """Build a ``data:image/...;base64,...`` URL for the image at ``img_pth``.

    The image subtype is taken from the text after the last ``.`` in the
    path, lower-cased, with ``jpg`` mapped to ``jpeg``. The payload is the
    Base64 text produced by ``encode_image``.

    Args:
        img_pth: Path of the image file.

    Returns:
        A data-URL string embedding the file contents.

    Note:
        A path containing no ``.`` is used in full as the subtype, so callers
        should pass paths that carry a file extension.
    """
    # Lower-case the suffix so "photo.JPG" is handled like "photo.jpg"
    # instead of producing "image/JPG".
    subtype = img_pth.split(".")[-1].lower()
    if subtype == "jpg":
        subtype = "jpeg"
    base64_image = encode_image(img_pth)
    return f"data:image/{subtype};base64,{base64_image}"
class base_llm:
    """Placeholder interface for LLM back ends.

    Every method here is a no-op that returns ``None``; subclasses supply the
    real behavior.
    """

    def __init__(self) -> None:
        """Create the back end; the base class stores no state."""
        return None

    def response(self, messages, **kwargs):
        """Placeholder for a chat call; does nothing and returns ``None``."""
        return None

    def get_imgs(self, prompt, save_path="saves/dalle3.jpg"):
        """Placeholder for image generation; does nothing and returns ``None``."""
        return None
class openai_llm(base_llm):
    """Azure OpenAI back end offering chat completions and text embeddings.

    A synchronous and an asynchronous client are created in ``__init__`` with
    identical settings. The remote-call methods never raise on API failure:
    the error is printed, logged, and ``None`` is returned instead.
    """

    def __init__(
        self,
        model=None,
        deployment=None,
        endpoint=None,
        api_key=None,
        api_version="2024-02-15-preview",
    ) -> None:
        """Create the sync and async Azure clients.

        Args:
            model: Default model name used by ``response`` / ``response_async``
                when the call does not pass ``model``.
            deployment: Forwarded to the clients as ``azure_deployment``.
            endpoint: Forwarded to the clients as ``azure_endpoint``.
            api_key: Forwarded to the clients as ``api_key``.
            api_version: Forwarded to the clients as ``api_version``. An empty
                string is converted to ``None`` (presumably so the SDK can
                apply its own default; confirm against the SDK docs).
        """
        super().__init__()
        self.model = model
        # Previously the version was hard-coded, which made the empty-string
        # check below unreachable; it is now a real, defaulted parameter.
        if api_version == "":
            api_version = None
        client_options = {
            "azure_deployment": deployment,
            "azure_endpoint": endpoint,
            "api_key": api_key,
            "api_version": api_version,
        }
        self.client = AzureOpenAI(**client_options)
        self.async_client = AsyncAzureOpenAI(**client_options)

    # ------------------------------------------------------------------
    # Private helpers shared by the sync and async code paths
    # ------------------------------------------------------------------
    def _chat_params(self, messages, options):
        """Build the keyword arguments for a chat-completion request."""
        return {
            "model": options.get("model", self.model),
            "messages": messages,
            "n": options.get("n", 1),
            "temperature": options.get("temperature", 0.7),
            "max_tokens": options.get("max_tokens", 4000),
            "timeout": options.get("timeout", 180),
        }

    @staticmethod
    def _dedicated_embedding_options():
        """Return client settings for a separate embedding endpoint, or ``None``.

        A dedicated client is used only when ``EMBEDDING_API_ENDPOINT`` is set
        to a non-empty value.
        """
        endpoint = os.environ.get("EMBEDDING_API_ENDPOINT")
        if not endpoint:
            return None
        # NOTE(review): the deployment is fixed here while the request model
        # comes from EMBEDDING_MODEL; confirm the two are meant to differ.
        return {
            "azure_endpoint": endpoint,
            "api_key": os.environ.get("EMBEDDING_API_KEY", None),
            "api_version": "2024-02-15-preview",
            "azure_deployment": "text-embedding-3-large",
        }

    @staticmethod
    def _embedding_params(text):
        """Build the keyword arguments for an embedding request."""
        return {
            "model": os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large"),
            "input": text,
            "timeout": 180,
        }

    @staticmethod
    def _unpack_embeddings(items):
        """Reduce the response's data list to ``None``, one vector, or a list."""
        if len(items) == 0:
            return None
        if len(items) == 1:
            return items[0].embedding
        return [item.embedding for item in items]

    @staticmethod
    def _report_failure(message):
        """Print a failure message once and record it at ERROR level."""
        print(message)
        logging.error(message)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def cal_cosine_similarity(self, vec1, vec2):
        """Return the cosine similarity of ``vec1`` and ``vec2``.

        Args:
            vec1: First vector (list or array-like).
            vec2: Second vector (list or array-like).

        Returns:
            ``dot(vec1, vec2) / (norm(vec1) * norm(vec2))``, or ``0.0`` when
            either norm is zero (the ratio is undefined there and would
            otherwise evaluate to ``nan`` with a runtime warning).
        """
        vec1 = np.asarray(vec1)
        vec2 = np.asarray(vec2)
        denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if denominator == 0:
            return 0.0
        return np.dot(vec1, vec2) / denominator

    def response(self, messages, **kwargs):
        """Request a chat completion and return the first choice's text.

        Args:
            messages: Chat messages forwarded unchanged to the API.
            **kwargs: Optional ``model``, ``n``, ``temperature``,
                ``max_tokens`` and ``timeout`` overrides.

        Returns:
            The content of the first returned choice, or ``None`` when the
            request raised an exception.
        """
        try:
            response = self.client.chat.completions.create(
                **self._chat_params(messages, kwargs)
            )
        except Exception as e:
            model = kwargs.get("model", self.model)
            self._report_failure(f"get {model} response failed: {e}")
            return None
        return response.choices[0].message.content

    def get_embbeding(self, text):
        """Embed ``text`` synchronously.

        Uses a dedicated client when ``EMBEDDING_API_ENDPOINT`` is set and the
        instance's own client otherwise.

        Args:
            text: Input forwarded as the request's ``input``.

        Returns:
            ``None`` when the response holds no data or the request failed, a
            single embedding when it holds exactly one item, otherwise a list
            of embeddings.
        """
        dedicated = self._dedicated_embedding_options()
        client = AzureOpenAI(**dedicated) if dedicated else self.client
        try:
            result = client.embeddings.create(**self._embedding_params(text))
            return self._unpack_embeddings(result.data)
        except Exception as e:
            self._report_failure(f"get embbeding failed: {e}")
            return None

    async def get_embbeding_async(self, text):
        """Embed ``text`` asynchronously; see ``get_embbeding`` for the contract.

        On failure the coroutine sleeps for 0.1 s before reporting the error
        and returning ``None``.
        """
        dedicated = self._dedicated_embedding_options()
        async_client = (
            AsyncAzureOpenAI(**dedicated) if dedicated else self.async_client
        )
        try:
            result = await async_client.embeddings.create(
                **self._embedding_params(text)
            )
            return self._unpack_embeddings(result.data)
        except Exception as e:
            await asyncio.sleep(0.1)
            self._report_failure(f"get embbeding failed: {e}")
            return None

    async def response_async(self, messages, **kwargs):
        """Asynchronous counterpart of ``response``.

        On failure the coroutine sleeps for 0.1 s before reporting the error
        and returning ``None``.
        """
        try:
            response = await self.async_client.chat.completions.create(
                **self._chat_params(messages, kwargs)
            )
        except Exception as e:
            await asyncio.sleep(0.1)
            model = kwargs.get("model", self.model)
            self._report_failure(f"get {model} response failed: {e}")
            return None
        return response.choices[0].message.content
if __name__ == "__main__":
    # Demo: embed a few questions and score each against a query sentence.
    import os
    import yaml  # not used by the visible demo code; kept as in the original

    def cal_cosine_similarity_matric(matric1, matric2):
        """Return all pairwise cosine similarities between two sets of vectors.

        Args:
            matric1: A vector or a 2-D collection of row vectors.
            matric2: A vector or a 2-D collection of row vectors.

        Returns:
            A flat Python list with one score per (row of ``matric1``, row of
            ``matric2``) pair, in row-major order.
        """
        matric1 = np.asarray(matric1)
        matric2 = np.asarray(matric2)
        # Promote single vectors to one-row matrices so both inputs are 2-D.
        if matric1.ndim == 1:
            matric1 = matric1.reshape(1, -1)
        if matric2.ndim == 1:
            matric2 = matric2.reshape(1, -1)
        dot_product = np.dot(matric1, matric2.T)
        norm1 = np.linalg.norm(matric1, axis=1)
        norm2 = np.linalg.norm(matric2, axis=1)
        cos_sim = dot_product / np.outer(norm1, norm2)
        # Return a plain list.
        return cos_sim.flatten().tolist()

    texts = ["What is the capital of France?","What is the capital of Spain?", "What is the capital of Italy?", "What is the capital of Germany?"]
    text = "What is the capital of France?"
    llm = openai_llm()
    embbedings = llm.get_embbeding(texts)
    embbeding = llm.get_embbeding(text)
    # get_embbeding returns None on failure; stop with a clear message
    # instead of failing later on a missing ``.shape`` attribute.
    if embbedings is None or embbeding is None:
        raise SystemExit("get embbeding failed; cannot compute similarity scores")
    scores = cal_cosine_similarity_matric(embbedings, embbeding)
    print(scores)