Spaces:
Runtime error
Runtime error
from functools import lru_cache

import numpy as np
import pandas as pd
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from sklearn.metrics.pairwise import cosine_similarity
# Sentence-embedding model used to vectorise both search queries and book titles.
model_name = "BAAI/bge-small-en-v1.5"
# normalize_embeddings=True returns unit-length vectors, so a plain dot product
# between two embeddings equals their cosine similarity.
encode_kwargs = dict(normalize_embeddings=True)
embeddings_model = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=dict(device='cpu'),  # run inference on the CPU
    encode_kwargs=encode_kwargs,
)
# Book catalogue, loaded once at import time. The path is relative, so it is
# resolved against the current working directory. Latin-1 maps every byte to a
# character, so decoding never fails on stray non-UTF-8 bytes.
_BOOKS_CSV = r'books.csv'
data = pd.read_csv(_BOOKS_CSV, encoding='latin1')
def retrieve_documents(query, *, frame=None):
    """Return the candidate titles to rank against *query*.

    No pre-filtering is done: *query* is currently unused and every title is
    returned, one per row and in row order. List position ``i`` therefore
    corresponds to row position ``i`` of the frame; answer() relies on this
    when it looks book details up with ``.iloc``.

    Missing titles become ``""`` and non-string titles are converted with
    ``str``, so the list can be passed straight to the embedding model (which
    expects text). Dropping such rows instead would break the row alignment.

    Args:
        query: Search text (unused; kept for interface compatibility).
        frame: Catalogue whose ``'Book Title'`` column is read. Defaults to
            the module-level ``data`` loaded from books.csv.

    Returns:
        A list of ``str``, one per row of the catalogue.
    """
    catalogue = data if frame is None else frame
    return catalogue['Book Title'].fillna('').astype(str).tolist()
@lru_cache(maxsize=1)
def _embed_titles(titles):
    """Embed every title in the tuple *titles*; return the vectors as a 2-D array.

    *titles* must be a tuple (hashable) so the call can be memoised. Only the
    most recent catalogue is kept. answer() embeds the same full title list on
    every query, so caching turns a per-query pass over the whole catalogue
    into a one-off cost.

    The returned array is shared between calls; treat it as read-only. If
    ``embeddings_model`` is replaced, call ``_embed_titles.cache_clear()``.
    """
    return np.asarray(embeddings_model.embed_documents(list(titles)))


def _book_details(index):
    """Build the result record for the book at row position *index* of ``data``."""
    return {
        "Book": data['Book Title'].iloc[index],
        "Author": data['Author'].iloc[index],
        "Edition": data['Edition'].iloc[index],
        "File Name": data['File_name'].iloc[index],
    }


def answer(query, min_similarity=0.7):
    """Return catalogue books whose title is semantically similar to *query*.

    The query and every title from retrieve_documents() are embedded with
    ``embeddings_model`` and compared by cosine similarity. Matches are
    returned from most to least similar.

    Args:
        query: Free-text search string.
        min_similarity: Exclusive lower bound; a title must score strictly
            above it to be included (default 0.7).

    Returns:
        A list of dicts with keys "Book", "Author", "Edition" and
        "File Name", taken from the matching rows of ``data``. The list is
        empty when nothing matches, and "No similar books found" is printed.
    """
    # retrieve_documents() returns one title per row of ``data`` in row order,
    # so a position in this list is also a row position in ``data``.
    retrieved_documents = retrieve_documents(query)
    if not retrieved_documents:
        # cosine_similarity() rejects an empty document matrix; an empty
        # catalogue simply has no matches.
        print("No similar books found")
        return []

    embedded_query = embeddings_model.embed_query(query)
    embedded_documents = _embed_titles(tuple(retrieved_documents))

    # Similarity of the query to each title (a single row of scores).
    similarities = cosine_similarity([embedded_query], embedded_documents)[0]
    # Title positions ordered from most to least similar.
    ranked_indices = np.argsort(similarities)[::-1]

    ranked_documents = []
    for index in ranked_indices:
        # Scores are visited in descending order: the first score that does
        # not clear the threshold means no later one can either.
        if not (similarities[index] > min_similarity):
            break
        ranked_documents.append(_book_details(index))

    if not ranked_documents:
        print("No similar books found")
    return ranked_documents
# Example usage | |
#query = "machine learning" | |
#result = answer(query) | |
#print(result) | |