# Hugging Face Spaces page header (status at capture time: Sleeping)
import os

import faiss
import numpy as np
import streamlit as st
import torch
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Function to load the dataset
def load_dataset(dataset_name="nq", out_dir="datasets", split="test"):
    """Load a BEIR dataset, downloading and extracting it on first use.

    Args:
        dataset_name: BEIR dataset identifier; it is both the archive name on
            the BEIR mirror (``<dataset_name>.zip``) and the extracted folder.
        out_dir: Directory the archive is downloaded to and extracted into.
        split: Which qrels split to load.

    Returns:
        ``(corpus, queries, qrels)`` exactly as returned by
        ``GenericDataLoader.load``.
    """
    # GenericDataLoader reads the *extracted* dataset folder, not the .zip
    # archive, so both the existence check and the loader use the folder path.
    data_path = os.path.join(out_dir, dataset_name)
    if not os.path.isdir(data_path):
        url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
        # download_and_unzip returns the path of the extracted folder.
        data_path = util.download_and_unzip(url, out_dir)
    corpus, queries, qrels = GenericDataLoader(data_path).load(split=split)
    return corpus, queries, qrels
# Function for candidate retrieval
def candidate_retrieval(corpus, queries, top_k=10,
                        model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Stage 1: dense retrieval of the ``top_k`` nearest passages per query.

    Every passage's ``"text"`` is embedded with a SentenceTransformer model and
    added to an exact (brute-force) FAISS L2 index; each query is embedded with
    the same model and searched against that index.

    Args:
        corpus: Mapping of passage id -> document dict with a ``"text"`` key.
        queries: Mapping of query id -> query text. Row ``i`` of the result
            corresponds to the ``i``-th query in this mapping's iteration order.
        top_k: Number of candidates per query (clamped to the corpus size).
        model_name: SentenceTransformer model used for passages and queries.

    Returns:
        ``(retrieved_indices, corpus_ids)``: an integer array of shape
        ``(len(queries), min(top_k, len(corpus)))`` whose rows hold positions
        into ``corpus_ids`` ordered by increasing L2 distance (best first), and
        the list of passage ids in the order they were indexed.

    Raises:
        ValueError: if ``corpus`` is empty.
    """
    corpus_ids = list(corpus.keys())
    if not corpus_ids:
        raise ValueError("candidate_retrieval() needs a non-empty corpus")
    # Never request more neighbours than there are vectors: FAISS pads missing
    # results with -1, which a later corpus_ids[idx] would silently turn into
    # the last passage.
    k = min(top_k, len(corpus_ids))
    query_texts = list(queries.values())
    if not query_texts:
        return np.empty((0, k), dtype=np.int64), corpus_ids

    embed_model = SentenceTransformer(model_name)
    corpus_texts = [corpus[pid]["text"] for pid in corpus_ids]
    # FAISS operates on C-contiguous float32 matrices; make that explicit.
    corpus_embeddings = np.ascontiguousarray(
        embed_model.encode(corpus_texts, convert_to_numpy=True), dtype=np.float32
    )
    index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
    index.add(corpus_embeddings)

    query_embeddings = np.ascontiguousarray(
        embed_model.encode(query_texts, convert_to_numpy=True), dtype=np.float32
    )
    _, retrieved_indices = index.search(query_embeddings, k)
    return retrieved_indices, corpus_ids
# Function for reranking
def rerank_passages(retrieved_indices, corpus, queries, corpus_ids=None,
                    model_name="cross-encoder/ms-marco-MiniLM-L-12-v2"):
    """Stage 2: rescore each query's candidates with a cross-encoder.

    Args:
        retrieved_indices: Per-query rows of positions into ``corpus_ids``, as
            returned by ``candidate_retrieval``. Negative entries (FAISS
            padding for missing results) are skipped.
        corpus: Mapping of passage id -> document dict with a ``"text"`` key.
        queries: Mapping of query id -> query text, iterated in the same order
            that produced ``retrieved_indices``.
        corpus_ids: Position -> passage id lookup used when building the
            index. Defaults to ``list(corpus.keys())``, which is what
            ``candidate_retrieval`` returns for an unmodified corpus.
        model_name: Hugging Face cross-encoder checkpoint.

    Returns:
        One list per query of ``(query, passage_text)`` tuples sorted by
        cross-encoder score, highest first; equal scores keep retrieval order.
    """
    if corpus_ids is None:
        corpus_ids = list(corpus.keys())
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    cross_encoder_model.eval()

    reranked_passages = []
    for row, query in zip(retrieved_indices, queries.values()):
        query_passage_pairs = [
            (query, corpus[corpus_ids[idx]]["text"]) for idx in row if idx >= 0
        ]
        if not query_passage_pairs:
            reranked_passages.append([])
            continue
        # Encode as (text, text_pair) batches: query first, passage second.
        inputs = tokenizer(
            [q for q, _ in query_passage_pairs],
            [p for _, p in query_passage_pairs],
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        # Inference only: no autograd graph, and plain Python floats to sort.
        with torch.no_grad():
            scores = cross_encoder_model(**inputs).logits.squeeze(-1).tolist()
        # sorted() is stable even with reverse=True, so ties keep FAISS order.
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        reranked_passages.append([query_passage_pairs[i] for i in order])
    return reranked_passages
# Function for evaluation
def _ranked_run(pids):
    """Map a best-first list of passage ids to a ``{pid: score}`` run dict."""
    run = {}
    total = len(pids)
    for rank, pid in enumerate(pids):
        # Strictly decreasing float scores encode the ranking for the metric;
        # setdefault keeps a repeated id at its best (earliest) rank.
        run.setdefault(pid, float(total - rank))
    return run


def evaluate(qrels, retrieved_indices, reranked_passages, queries,
             corpus=None, corpus_ids=None, k=10):
    """Compute NDCG@k for the retrieval stage and for the reranking stage.

    Each stage's output is turned into a run whose scores follow that stage's
    ranked order. Constant scores would discard the ordering, and since
    reranking only reorders the same candidates, both stages would then score
    identically.

    Args:
        qrels: Relevance judgements as returned by ``load_dataset``.
        retrieved_indices: Stage-1 rows of positions into ``corpus_ids``, best
            first; negative entries are skipped.
        reranked_passages: Stage-2 lists of ``(query, passage_text)`` tuples,
            best first, as returned by ``rerank_passages``.
        queries: Mapping of query id -> text; its iteration order lines up with
            the rows of both stage outputs.
        corpus: Mapping of passage id -> document dict with a ``"text"`` key;
            required to map reranked texts back to passage ids.
        corpus_ids: Position -> passage id lookup from ``candidate_retrieval``;
            defaults to ``list(corpus.keys())``.
        k: NDCG cutoff.

    Returns:
        ``(ndcg_stage1, ndcg_stage2)``: the ``NDCG@k`` values reported by
        ``EvaluateRetrieval.evaluate`` for each stage.

    Raises:
        ValueError: if ``corpus`` is not supplied.
    """
    if corpus is None:
        raise ValueError("evaluate() needs `corpus` to map reranked passages back to ids")
    if corpus_ids is None:
        corpus_ids = list(corpus.keys())
    query_ids = list(queries.keys())

    # Stage 1: FAISS rows are already ordered best match first.
    results_stage1 = {
        qid: _ranked_run([corpus_ids[idx] for idx in row if idx >= 0])
        for qid, row in zip(query_ids, retrieved_indices)
    }

    # Stage 2: resolve each passage text to the first passage id (in corpus
    # order) carrying that text. One pass over the corpus instead of a full
    # scan for every reranked passage.
    text_to_pid = {}
    for pid, doc in corpus.items():
        text_to_pid.setdefault(doc["text"], pid)
    results_stage2 = {
        qid: _ranked_run(
            [text_to_pid[p[1]] for p in passages if p[1] in text_to_pid]
        )
        for qid, passages in zip(query_ids, reranked_passages)
    }

    evaluator = EvaluateRetrieval()
    metric = f"NDCG@{k}"
    # evaluate() returns a tuple (ndcg, map, recall, precision) of dicts.
    ndcg_stage1 = evaluator.evaluate(qrels, results_stage1, [k])[0]
    ndcg_stage2 = evaluator.evaluate(qrels, results_stage2, [k])[0]
    return ndcg_stage1[metric], ndcg_stage2[metric]


# Streamlit app
def main():
    """Streamlit UI that runs the pipeline one stage per button click.

    Streamlit re-executes the script on every interaction, and ``st.button``
    is True only on the rerun triggered by its own click. A button nested under
    another button can therefore never fire. Each stage instead stores its
    output in ``st.session_state``, and the next stage's button is shown once
    that output exists.
    """
    st.title("Multi-Stage Text Retrieval Pipeline")
    state = st.session_state

    if st.button("Load Dataset"):
        corpus, queries, qrels = load_dataset()
        state["corpus"], state["queries"], state["qrels"] = corpus, queries, qrels
        # A freshly loaded dataset invalidates every downstream result.
        for key in ("retrieved_indices", "corpus_ids", "reranked_passages"):
            if key in state:
                del state[key]
        st.success("Dataset loaded successfully!")
    if "corpus" not in state:
        return

    if st.button("Run Candidate Retrieval"):
        retrieved_indices, corpus_ids = candidate_retrieval(state["corpus"], state["queries"])
        state["retrieved_indices"], state["corpus_ids"] = retrieved_indices, corpus_ids
        if "reranked_passages" in state:
            del state["reranked_passages"]
        st.success("Candidate retrieval completed!")
        st.write("Retrieved indices:", retrieved_indices)
    if "retrieved_indices" not in state:
        return

    if st.button("Run Reranking"):
        reranked_passages = rerank_passages(
            state["retrieved_indices"], state["corpus"], state["queries"]
        )
        state["reranked_passages"] = reranked_passages
        st.success("Reranking completed!")
        st.write("Reranked passages:", reranked_passages)
    if "reranked_passages" not in state:
        return

    if st.button("Evaluate"):
        ndcg_score_stage1, ndcg_score_stage2 = evaluate(
            state["qrels"],
            state["retrieved_indices"],
            state["reranked_passages"],
            state["queries"],
            corpus=state["corpus"],
            corpus_ids=state["corpus_ids"],
        )
        st.write(f"NDCG@10 for Stage 1 (Candidate Retrieval): {ndcg_score_stage1}")
        st.write(f"NDCG@10 for Stage 2 (Reranking): {ndcg_score_stage2}")


if __name__ == "__main__":
    main()