# final_book_retriever / final_book_retriever.py
# Author: achdaisy
# Commit: a4dde3f (verified) — "Update final_book_retriever.py"
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# Sentence-embedding model (BAAI BGE small, English, v1.5) used to turn both
# search queries and book titles into vectors; inference runs on the CPU.
model_name = "BAAI/bge-small-en-v1.5"
# Unit-normalise every embedding so that vector comparisons are cosine-based.
encode_kwargs = dict(normalize_embeddings=True)
embeddings_model = HuggingFaceBgeEmbeddings(
    model_name=model_name,
    model_kwargs=dict(device="cpu"),
    encode_kwargs=encode_kwargs,
)
# Book catalogue loaded once at import time. Rows are expected to carry the
# 'Book Title', 'Author', 'Edition' and 'File_name' columns.
# NOTE(review): decoded as Latin-1, presumably because the CSV is not UTF-8 — confirm.
data = pd.read_csv("books.csv", encoding="latin1")
def retrieve_documents(query):
    """Return the candidate documents to be ranked for ``query``.

    No pre-filtering is performed: ``query`` is ignored and every value of
    the catalogue's 'Book Title' column is returned, in row order, as a list.
    """
    title_column = data['Book Title']
    return title_column.tolist()
def answer(query, min_similarity=0.7):
    """Find catalogue books whose titles are semantically similar to ``query``.

    The query and every candidate title from ``retrieve_documents`` are
    embedded with ``embeddings_model`` and compared by cosine similarity.
    Matches are returned best first, keeping only titles whose (unrounded)
    similarity is strictly greater than ``min_similarity``.

    Args:
        query: Free-text search string.
        min_similarity: Exclusive lower bound a title's cosine similarity
            must exceed to be returned.

    Returns:
        A list of dicts with keys "Book", "Author", "Edition", "File Name"
        and "Similarity Score" (the similarity rounded to two decimals),
        ordered by descending similarity. When nothing qualifies, prints
        "No similar books found" and returns an empty list.
    """
    retrieved_documents = retrieve_documents(query)

    ranked_documents = []
    # sklearn's cosine_similarity rejects an empty sample matrix, so an
    # empty catalogue is short-circuited instead of raising ValueError.
    if retrieved_documents:
        embedded_query = embeddings_model.embed_query(query)
        embedded_documents = embeddings_model.embed_documents(retrieved_documents)
        scores = cosine_similarity([embedded_query], embedded_documents)[0]

        # Descending order via a stable sort on negated scores: ties keep
        # catalogue order, and any NaN score sorts last instead of first
        # (reversing an ascending argsort would put NaN at the front and
        # trigger the early break before any valid match).
        for index in np.argsort(-scores, kind="stable"):
            score = scores[index]
            # Filter on the raw score; rounding first would wrongly drop
            # scores just above the threshold (e.g. 0.704 -> 0.7 with the
            # default). `not >` also treats NaN as "below threshold".
            if not score > min_similarity:
                break  # scores are descending, so no later title can pass
            # NOTE: positions index `data` directly, which relies on
            # retrieve_documents returning every title in row order.
            ranked_documents.append({
                "Book": data['Book Title'].iloc[index],
                "Author": data['Author'].iloc[index],
                "Edition": data['Edition'].iloc[index],
                "File Name": data['File_name'].iloc[index],
                "Similarity Score": round(score, 2),  # rounded for display only
            })

    if not ranked_documents:
        print("No similar books found")
    return ranked_documents
# Example usage
#query = "machine learning"
#result = answer(query)
#print(result)