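# Multi-stage text retrieval demo: load the BEIR NQ dataset, retrieve candidate
# passages with a bi-encoder + FAISS, rerank them with a cross-encoder, and
# report NDCG@10 for both stages through a small Streamlit UI.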
import os
import numpy as np
import faiss
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
# Function to load the BEIR NQ dataset (downloaded and unzipped on first run)
def load_dataset():
    dataset_name = "nq"
    data_path = f"datasets/{dataset_name}"
    if not os.path.exists(data_path):
        url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset_name}.zip"
        data_path = util.download_and_unzip(url, "datasets/")
    corpus, queries, qrels = GenericDataLoader(data_path).load(split="test")
    return corpus, queries, qrels
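# For reference, GenericDataLoader returns plain dicts (shapes assumed from the
# standard BEIR format; the values below are illustrative, not real data):
#   corpus:  {"doc1": {"title": "...", "text": "..."}, ...}
#   queries: {"q1": "who wrote ...", ...}
#   qrels:   {"q1": {"doc1": 1, ...}, ...}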
# Function for candidate retrieval with a bi-encoder and a FAISS index
def candidate_retrieval(corpus, queries):
    embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    corpus_ids = list(corpus.keys())
    corpus_texts = [corpus[pid]["text"] for pid in corpus_ids]
    # Embed the corpus once and index it for exact L2 search
    corpus_embeddings = embed_model.encode(corpus_texts, convert_to_numpy=True)
    index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
    index.add(corpus_embeddings)
    # Embed all queries and fetch the top-10 nearest passages per query
    query_texts = list(queries.values())
    query_embeddings = embed_model.encode(query_texts, convert_to_numpy=True)
    _, retrieved_indices = index.search(query_embeddings, 10)
    return retrieved_indices, corpus_ids
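# Note: all-MiniLM-L6-v2 embeddings are commonly compared with cosine similarity
# rather than raw L2 distance. A minimal alternative sketch (not used above):
#
#   faiss.normalize_L2(corpus_embeddings)                 # unit-length vectors
#   index = faiss.IndexFlatIP(corpus_embeddings.shape[1])
#   index.add(corpus_embeddings)
#   faiss.normalize_L2(query_embeddings)
#   scores, retrieved_indices = index.search(query_embeddings, 10)  # IP == cosine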
# Function for reranking the candidates with a cross-encoder
def rerank_passages(retrieved_indices, corpus, corpus_ids, queries):
    model_name = "cross-encoder/ms-marco-MiniLM-L-12-v2"
    cross_encoder_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    cross_encoder_model.eval()
    reranked_passages = []
    for i, query in enumerate(queries.values()):
        # Score each (query, passage) pair, keeping the passage ids for evaluation
        doc_ids = [corpus_ids[idx] for idx in retrieved_indices[i]]
        passages = [corpus[doc_id]["text"] for doc_id in doc_ids]
        inputs = tokenizer([query] * len(passages), passages, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            scores = cross_encoder_model(**inputs).logits.squeeze(-1)
        # Sort passage ids by descending cross-encoder score
        reranked_passages.append(sorted(zip(doc_ids, scores.tolist()), key=lambda x: x[1], reverse=True))
    return reranked_passages
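# Note: sentence-transformers also provides a CrossEncoder wrapper that handles
# tokenization and batching internally; an equivalent sketch with the same model:
#
#   from sentence_transformers import CrossEncoder
#   ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-12-v2")
#   pairs = [(query, p) for p in passages]
#   scores = ce.predict(pairs)  # one relevance score per (query, passage) pair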
# Function for evaluation: NDCG@10 for both stages via BEIR's EvaluateRetrieval
def evaluate(qrels, retrieved_indices, reranked_passages, corpus_ids, queries):
    evaluator = EvaluateRetrieval()
    # Stage 1: score candidates by descending retrieval rank so NDCG respects the FAISS ordering
    results_stage1 = {}
    for i, query_id in enumerate(queries.keys()):
        results_stage1[query_id] = {
            corpus_ids[idx]: float(len(retrieved_indices[i]) - rank)
            for rank, idx in enumerate(retrieved_indices[i])
        }
    ndcg_stage1, _, _, _ = evaluator.evaluate(qrels, results_stage1, [10])
    # Stage 2: use the cross-encoder scores directly
    results_stage2 = {}
    for i, query_id in enumerate(queries.keys()):
        results_stage2[query_id] = dict(reranked_passages[i])
    ndcg_stage2, _, _, _ = evaluator.evaluate(qrels, results_stage2, [10])
    return ndcg_stage1["NDCG@10"], ndcg_stage2["NDCG@10"]
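# For reference, EvaluateRetrieval.evaluate returns four dicts keyed by cutoff
# (NDCG, MAP, Recall, Precision); only the NDCG dict, e.g. {"NDCG@10": <float>},
# is used here.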
# Streamlit app: each button press reruns the whole script, so intermediate
# results are persisted in st.session_state between steps
def main():
    st.title("Multi-Stage Text Retrieval Pipeline")
    if st.button("Load Dataset"):
        corpus, queries, qrels = load_dataset()
        st.session_state["corpus"], st.session_state["queries"], st.session_state["qrels"] = corpus, queries, qrels
        st.success("Dataset loaded successfully!")
    if st.button("Run Candidate Retrieval"):
        retrieved_indices, corpus_ids = candidate_retrieval(st.session_state["corpus"], st.session_state["queries"])
        st.session_state["retrieved_indices"], st.session_state["corpus_ids"] = retrieved_indices, corpus_ids
        st.success("Candidate retrieval completed!")
        st.write("Retrieved indices:", retrieved_indices)
    if st.button("Run Reranking"):
        reranked_passages = rerank_passages(st.session_state["retrieved_indices"], st.session_state["corpus"],
                                            st.session_state["corpus_ids"], st.session_state["queries"])
        st.session_state["reranked_passages"] = reranked_passages
        st.success("Reranking completed!")
        st.write("Reranked (doc_id, score) pairs:", reranked_passages)
    if st.button("Evaluate"):
        ndcg_score_stage1, ndcg_score_stage2 = evaluate(st.session_state["qrels"],
                                                        st.session_state["retrieved_indices"],
                                                        st.session_state["reranked_passages"],
                                                        st.session_state["corpus_ids"],
                                                        st.session_state["queries"])
        st.write(f"NDCG@10 for Stage 1 (Candidate Retrieval): {ndcg_score_stage1}")
        st.write(f"NDCG@10 for Stage 2 (Reranking): {ndcg_score_stage2}")

if __name__ == "__main__":
    main()
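# Run locally with: streamlit run app.py
# Note: the full BEIR NQ corpus is large (millions of passages), so embedding it
# end-to-end is expensive; for a quick demo, consider slicing the corpus first.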