Spaces:
Runtime error
Runtime error
import logging
import os
from typing import List, Tuple
from typing import List, Tuple, Dict

import pandas as pd
import requests
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
# Configure logging once for the module: timestamp, level name, then message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# OpenAI-compatible embeddings API configuration.
# SECURITY: the literal below is a credential committed to source control. It is
# kept only as a fallback so existing deployments keep working; set the
# EMBEDDINGS_API_KEY environment variable instead, then rotate and remove it.
API_KEY = os.environ.get(
    "EMBEDDINGS_API_KEY",
    "sk-u0S4iYA2kJmaDNBgBc48D2A6Fa904fF0B6E19dF0F6A39717",
)
API_URL = "https://api.ltcld.cn/v1/embeddings"
MODEL = "text-embedding-ada-002"
def generate_embeddings(text: str, timeout: float = 30.0) -> List[float]:
    """Fetch the embedding vector for ``text`` from the configured embeddings API.

    Args:
        text: The string to embed.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The embedding as returned under ``data[0].embedding`` in the JSON
        response, or an empty list when ``text`` is empty, the request fails,
        or the response does not have the expected shape. Failures are logged
        and never raised.
    """
    if not text:
        # The OpenAI embeddings API documents that input cannot be an empty
        # string, so skip the network round trip and report the failure here.
        logging.error("OpenAI API request failed: input text is empty")
        return []

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    payload = {
        "input": text,
        "model": MODEL,
    }
    try:
        # Without an explicit timeout requests waits indefinitely on a stalled
        # connection, which would block the caller forever.
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()["data"][0]["embedding"]
    except requests.exceptions.RequestException as e:
        logging.error(f"OpenAI API request failed: {e}")
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # A successful status with a non-JSON or unexpectedly shaped body must
        # degrade to the same "no embedding" result instead of crashing.
        logging.error(f"OpenAI API returned an unexpected response: {e!r}")
    return []
def rerank_results(query: str, documents: List[str]) -> List[Tuple[str, float]]:
    """Rank ``documents`` by TF-IDF cosine similarity to ``query``.

    The TF-IDF vocabulary is fitted on ``documents`` only; the query is then
    projected into that vocabulary before the similarities are computed.

    Args:
        query: The search text to compare against.
        documents: Candidate texts to score.

    Returns:
        ``(document, score)`` pairs sorted by descending score. Documents with
        equal scores keep their input order. An empty ``documents`` list yields
        an empty result.
    """
    if not documents:
        # TfidfVectorizer.fit_transform raises ValueError on an empty corpus.
        return []

    vectorizer = TfidfVectorizer()
    doc_matrix = vectorizer.fit_transform(documents)
    query_vector = vectorizer.transform([query])
    scores = cosine_similarity(query_vector, doc_matrix).flatten()
    # sorted() is stable, so ties stay in input order just like list.sort().
    return sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
def search_dataset(queries: List[str], top_k: int = 5, similarity_threshold: float = 0.5) -> List[dict]:
    """Find stored question/answer pairs that match any of ``queries``.

    For every query the ``top_k`` rows of the module-level ``df`` with the
    highest embedding cosine similarity are selected and rows scoring below
    ``similarity_threshold`` are dropped. The hits of all queries are merged,
    scored with ``rerank_results`` and de-duplicated.

    Args:
        queries: Search strings to look up.
        top_k: Maximum number of rows taken per query before thresholding.
        similarity_threshold: Minimum embedding cosine similarity for a hit.

    Returns:
        Dicts with the keys ``question``, ``answer``, ``similarity`` (embedding
        cosine similarity) and ``score`` (best rerank score over all queries),
        ordered by descending ``score``. Empty when ``queries`` is empty, no
        query could be embedded, or nothing passes the threshold.

    Side effects:
        Overwrites ``df['Similarity']`` with the scores of the most recently
        embedded query and prints intermediate results to stdout.
    """
    if not queries:
        return []

    results = []
    for query in queries:
        query_embedding = generate_embeddings(query)
        if len(query_embedding) == 0:
            # generate_embeddings returns [] on failure; an empty vector would
            # make cosine_similarity raise, so skip this query instead.
            logging.warning("Skipping query %r: no embedding available", query)
            continue

        embeddings = df['Embedding'].tolist()
        if not embeddings:
            # Nothing to compare against; cosine_similarity rejects empty input.
            return []

        similarity_scores = cosine_similarity([query_embedding], embeddings)[0]
        # NOTE: this writes to the shared module-level frame on every query.
        df['Similarity'] = similarity_scores
        print(f"Similarity scores for query '{query}': {similarity_scores}")

        top_results = df.sort_values('Similarity', ascending=False).head(top_k)
        print(f"Top {top_k} results for query '{query}':")
        print(top_results)

        query_results = []
        for _, row in top_results.iterrows():
            if row['Similarity'] >= similarity_threshold:
                query_results.append({
                    'question': row['Question'],
                    'answer': row['Answer'],
                    'similarity': row['Similarity'],
                })
        print(f"Filtered results for query '{query}': {query_results}")
        results.append(query_results)

    merged_results = [hit for query_results in results for hit in query_results]
    print(f"Merged results: {merged_results}")
    if not merged_results:
        # No hit survived the threshold; reranking an empty corpus would raise.
        return []

    documents = [result['question'] + ' ' + result['answer'] for result in merged_results]

    # Several queries can hit the same text; the earliest hit represents it.
    first_hit = {}
    for doc, result in zip(documents, merged_results):
        first_hit.setdefault(doc, result)

    # Score every document against every query and keep its best score, so a
    # hit is never ranked only by a query it was not retrieved for.
    best_score = {}
    for query in queries:
        for doc, score in rerank_results(query, documents):
            if doc not in best_score or score > best_score[doc]:
                best_score[doc] = score

    # Stable sort: equally scored documents stay in first-seen order.
    ranked_docs = sorted(first_hit, key=lambda doc: best_score[doc], reverse=True)

    unique_results = []
    seen_questions = set()
    seen_answers = set()
    for doc in ranked_docs:
        result = first_hit[doc]
        result['score'] = best_score[doc]
        # A result is dropped when either its question or its answer has
        # already been emitted by a higher-ranked result.
        if result['question'] in seen_questions or result['answer'] in seen_answers:
            continue
        unique_results.append(result)
        seen_questions.add(result['question'])
        seen_answers.add(result['answer'])
    print(f"Unique results: {unique_results}")

    filtered_results = [result for result in unique_results if result['similarity'] >= similarity_threshold]
    print(f"Filtered results: {filtered_results}")
    return filtered_results
# Load the precomputed question/answer table once at import time and turn the
# stringified 'Embedding' column back into Python objects.
# NOTE(review): eval() executes whatever expression each CSV cell contains, so
# this is only safe while output/qa_embeddings.csv is fully trusted. If the
# cells are plain list literals, ast.literal_eval would be a safer parser.
df = pd.read_csv('output/qa_embeddings.csv').assign(
    Embedding=lambda frame: frame['Embedding'].apply(eval)
)
# search_queries = ["原神","minecraft"] | |
# search_results = search_dataset(search_queries, top_k=1, similarity_threshold=0.5) | |
# for i, result in enumerate(search_results): | |
# print(f"Search Result {i+1}:") | |
# print(f"Question: {result['question']}") | |
# print(f"Answer: {result['answer']}") | |
# print(f"Similarity: {result['similarity']}") | |
# print(f"Rerank Score: {result['score']}") | |
# print("----------------------------------------------------") |