inksiyu's picture
Upload 14 files
c871381 verified
import json
import logging
import os
import uuid
from typing import List, Tuple

import pandas as pd
import requests
# Configure logging: INFO level, "<time> - <level> - <message>" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# OpenAI API configuration. API_KEY and API_URL start out empty and are
# overwritten at runtime by generate_embeddings_from_qa_pairs().
API_KEY = ""
API_URL = ""
MODEL = "text-embedding-ada-002"

# Local proxy endpoint, the same address for both http and https traffic.
# NOTE(review): nothing in the visible part of this file passes `proxies`
# to requests -- confirm whether it is still needed.
proxies = dict.fromkeys(("http", "https"), "http://127.0.0.1:10808")
def generate_embeddings_from_qa_pairs(qa_pairs: List[Tuple[str, str]], api_key, api_url_base):
    """Embed every question/answer pair and save the result as a CSV file.

    Stores the credentials in the module-level ``API_KEY`` / ``API_URL``
    globals (read by ``generate_embeddings``), builds the DataFrame with
    ``process_qa_pairs``, writes it to a randomly named file under
    ``output/`` and prints both the file name and the DataFrame.

    Args:
        qa_pairs: ``(question, answer)`` tuples to embed.
        api_key: Token sent as ``Bearer`` in the ``Authorization`` header.
        api_url_base: Base URL of the API; ``/embeddings`` is appended.
            Trailing slashes are tolerated.
    """
    global API_KEY, API_URL
    API_KEY = api_key
    # Drop trailing slashes so ".../v1/" does not become ".../v1//embeddings".
    base_url = str(api_url_base).rstrip("/")
    API_URL = f"{base_url}/embeddings"

    df = process_qa_pairs(qa_pairs)

    # to_csv does not create missing parent directories, so make sure the
    # target directory exists before writing.
    os.makedirs("output", exist_ok=True)
    random_filename = f"output/qa_embeddings_{uuid.uuid4().hex}.csv"
    df.to_csv(random_filename, index=False)
    print(f"Embeddings saved to {random_filename}")
    print(df)
def generate_embeddings(text: str, timeout: float = 60.0) -> List[float]:
    """Request an embedding vector for *text* from the configured API.

    Posts ``{"input": text, "model": MODEL}`` to the module-level ``API_URL``
    with ``API_KEY`` as bearer token and returns ``data[0].embedding`` from
    the JSON response.

    This is a best-effort call: any transport/HTTP error, timeout, or
    response body that does not have the expected shape is logged and an
    empty list is returned instead of raising.

    Args:
        text: The text to embed.
        timeout: Seconds to wait for the server before giving up; without
            it a stalled connection would block forever.

    Returns:
        The embedding as a list of floats, or ``[]`` on failure.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    payload = {
        "input": text,
        "model": MODEL,
    }
    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()["data"][0]["embedding"]
    except requests.exceptions.RequestException as e:
        logging.error("OpenAI API request failed: %s", e)
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # A 2xx reply whose body is not JSON or lacks data[0].embedding must
        # not crash the whole batch; treat it like any other failed request.
        logging.error("OpenAI API returned an unexpected response: %r", e)
    return []
def process_qa_pairs(qa_pairs: List[Tuple[str, str]]) -> pd.DataFrame:
    """Build a DataFrame holding each Q/A pair together with its embedding.

    For every ``(question, answer)`` tuple the two texts are joined into
    ``"Question: ...\\nAnswer: ..."`` and passed to ``generate_embeddings``.

    Args:
        qa_pairs: ``(question, answer)`` tuples.

    Returns:
        A DataFrame with the columns ``Question``, ``Answer``,
        ``Combined_Text`` and ``Embedding``, one row per input pair. The
        columns are present even when *qa_pairs* is empty, so downstream
        code and the written CSV always see the same schema.
    """
    columns = ["Question", "Answer", "Combined_Text", "Embedding"]
    rows = []
    for q, a in qa_pairs:
        combined_text = f"Question: {q}\nAnswer: {a}"
        rows.append({
            "Question": q,
            "Answer": a,
            "Combined_Text": combined_text,
            "Embedding": generate_embeddings(combined_text),
        })
    # Passing the columns explicitly keeps the schema for empty input.
    return pd.DataFrame(rows, columns=columns)
if __name__ == "__main__":
    # Demo run: embed a few sample Q/A pairs and dump them to a CSV file.
    # NOTE(review): this path never sets API_KEY / API_URL, so it runs with
    # the module defaults -- confirm they are filled in before running.
    qa_pairs = [
        ("秦始皇在中国历史的地位如何?", "秦始皇在中国历史上具有极其重要的地位,他建立了第一个中央集权国家,也是中国第一位称皇帝的君主。他在十年时间里兼并六国,结束了春秋战国五百年分裂局面,使天下归于一统。他成功的把原本七零八散的地域文明重新整合在一起,成为中国此后两千多年封建文明的基石。统一的疆域、统一的文明、统一的法制和中央集权,使中华民族以一个整体骄傲的屹立在世界东方。"),
        ("秦始皇在世界历史上的地位如何?", "秦始皇在世界史上也具有一定的地位。他建立了中国历史上第一个大一统的中央集权国家,结束了兵荒马乱的战国时代。他的统一疆域、统一文明、统一法制和中央集权的成就,使得中华民族以一个整体骄傲的屹立在世界东方。"),
        ("Minecraft是什么类型的游戏?", "《我的世界》(Minecraft)是一款沙盒类电子游戏。")
    ]
    df = process_qa_pairs(qa_pairs)

    # to_csv does not create missing parent directories, so make sure the
    # target directory exists before writing.
    os.makedirs("output", exist_ok=True)
    random_filename = f"output/qa_embeddings_{uuid.uuid4().hex}.csv"
    df.to_csv(random_filename, index=False)
    print(f"Embeddings saved to {random_filename}")
    print(df)