Hyma7's picture
Update app.py
8dc7a1c verified
import streamlit as st
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import numpy as np
from sklearn.metrics import ndcg_score
# Helper function to load the dataset
def download_and_extract_dataset():
import urllib.request
import zipfile
import os
dataset_url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip"
dataset_zip_path = "nq.zip"
data_path = "./datasets/nq"
# Download the dataset if not already downloaded
if not os.path.exists(dataset_zip_path):
st.write("Downloading the dataset... This may take a few minutes.")
urllib.request.urlretrieve(dataset_url, dataset_zip_path)
st.write("Download complete!")
# Unzip the dataset if not already unzipped
if not os.path.exists(data_path):
st.write("Unzipping the dataset...")
with zipfile.ZipFile(dataset_zip_path, 'r') as zip_ref:
zip_ref.extractall("./datasets")
st.write("Dataset unzipped!")
return data_path
# Function to load corpus, queries, and qrels
def load_dataset():
from beir.datasets.data_loader import GenericDataLoader
data_path = download_and_extract_dataset()
corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
return corpus, queries, qrels
# Stage 1: Candidate retrieval using Sentence Transformer
def candidate_retrieval(query, corpus, top_k=10):
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
corpus_ids = list(corpus.keys())
corpus_embeddings = model.encode([corpus[doc_id]['text'] for doc_id in corpus_ids], convert_to_tensor=True)
query_embedding = model.encode(query, convert_to_tensor=True)
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)[0]
retrieved_docs = [corpus_ids[hit['corpus_id']] for hit in hits]
return retrieved_docs
# Stage 2: Reranking using cross-encoder
def rerank(retrieved_docs, query, corpus, top_k=5):
tokenizer = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
model = AutoModelForSequenceClassification.from_pretrained("cross-encoder/ms-marco-MiniLM-L-12-v2")
scores = []
for doc_id in retrieved_docs:
text = corpus[doc_id]['text']
inputs = tokenizer(query, text, return_tensors="pt", truncation=True, padding=True)
outputs = model(**inputs)
scores.append(outputs.logits.item())
reranked_indices = np.argsort(scores)[::-1][:top_k]
reranked_docs = [retrieved_docs[idx] for idx in reranked_indices]
return reranked_docs, scores
# Function to evaluate using NDCG@10
def evaluate_ndcg(reranked_docs, qrels, query_id, k=10):
true_relevance = [qrels.get((query_id, doc_id), 0) for doc_id in reranked_docs]
ideal_relevance = sorted(true_relevance, reverse=True)
# NDCG expects input as 2D arrays
return ndcg_score([ideal_relevance], [true_relevance], k=k)
# Streamlit main function
def main():
st.title("Multi-Stage Retrieval Pipeline with Evaluation")
st.write("Loading the dataset...")
corpus, queries, qrels = load_dataset()
st.write(f"Corpus Size: {len(corpus)}")
# User input for asking a question
user_query = st.text_input("Ask a question:")
if user_query:
st.write(f"Your query: {user_query}")
st.write("Running Candidate Retrieval...")
retrieved_docs = candidate_retrieval(user_query, corpus, top_k=10)
st.write("Running Reranking...")
reranked_docs, rerank_scores = rerank(retrieved_docs, user_query, corpus, top_k=5)
st.write("Top Reranked Documents:")
for doc_id in reranked_docs:
st.write(f"Document ID: {doc_id}")
st.write(f"Document Text: {corpus[doc_id]['text'][:500]}...") # Show the first 500 characters of the document
# Evaluation if the user query exists in the qrels (ground truth relevance labels)
query_id = list(queries.keys())[0] # Dummy query ID for now
if query_id in queries:
ndcg_score_value = evaluate_ndcg(reranked_docs, qrels, query_id, k=10)
st.write(f"NDCG@10 Score: {ndcg_score_value}")
else:
st.write("No ground truth available for this query.")
st.write("Query executed successfully!")
if __name__ == "__main__":
main()