import os
import faiss
import numpy as np
from rank_bm25 import BM25Okapi
import torch
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
# Set Hugging Face cache directory
os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"
app = FastAPI()
# Ensure the correct file path
DATASET_PATH = os.path.join(os.getcwd(), "springer_papers_DL.json")
if not os.path.exists(DATASET_PATH):
    raise FileNotFoundError(f"Dataset file not found at {DATASET_PATH}")
# Load dataset
df = pd.read_json(DATASET_PATH)
# Clean text function
def clean_text(text):
    return text.strip().lower()
df['cleaned_abstract'] = df['abstract'].apply(clean_text)
# Precompute BM25 Index
tokenized_corpus = [paper.split() for paper in df["cleaned_abstract"]]
bm25 = BM25Okapi(tokenized_corpus)
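# BM25 provides sparse keyword scoring over the whitespace-tokenized abstracts;
# it is combined with the dense FAISS search in the /predict/ endpoint below.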
# Load SciBERT tokenizer and encoder (used to embed abstracts for the FAISS index)
embedding_model = "allenai/scibert_scivocab_uncased"
tokenizer = AutoTokenizer.from_pretrained(embedding_model)
model = AutoModel.from_pretrained(embedding_model)
# Generate embeddings using SciBERT
def generate_embeddings_sci_bert(texts, batch_size=32):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512)
        inputs = {key: val.to(device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
            # Mean-pool the last hidden states to get one fixed-size vector per text
            embeddings = outputs.last_hidden_state.mean(dim=1)
        all_embeddings.append(embeddings.cpu().numpy())
    return np.concatenate(all_embeddings, axis=0)
# Compute document embeddings
abstracts = df["cleaned_abstract"].tolist()
embeddings = generate_embeddings_sci_bert(abstracts, batch_size=32)
# Initialize FAISS index
dimension = embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(embeddings.astype(np.float32))
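# Note: IndexFlatL2 performs exact (brute-force) L2 search; FAISS expects float32 vectors,
# hence the astype(np.float32) casts here and in the /predict/ handler.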
# API Request Model
class InputText(BaseModel):
    query: str
    top_k: int = 5
@app.post("/predict/")
async def predict(data: InputText):
    query = data.query
    top_k = data.top_k
    if not query.strip():
        return {"error": "Query is empty. Please enter a valid search query."}
    # 1️⃣ Generate embedding for the query
    query_embedding = generate_embeddings_sci_bert([query], batch_size=1)
    # 2️⃣ Perform FAISS similarity search
    distances, indices = faiss_index.search(query_embedding.astype(np.float32), top_k)
    # 3️⃣ Perform BM25 keyword search
    tokenized_query = query.split()
    bm25_scores = bm25.get_scores(tokenized_query)
    bm25_top_indices = np.argsort(bm25_scores)[::-1][:top_k]
    # 4️⃣ Combine FAISS and BM25 candidates, then rank the union by BM25 score
    combined_indices = list(set(indices[0]) | set(bm25_top_indices))
    ranked_results = sorted(combined_indices, key=lambda idx: -bm25_scores[idx])
    # 5️⃣ Retrieve the top-ranked papers
    relevant_papers = []
    for i, index in enumerate(ranked_results[:top_k]):
        paper = df.iloc[index]
        relevant_papers.append({
            "rank": i + 1,
            "title": paper["title"],
            "authors": paper["authors"],
            "abstract": paper["cleaned_abstract"]
        })
    return {"results": relevant_papers}
# Run FastAPI
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0")
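# Example request (illustrative; assumes uvicorn's default port 8000):
#   curl -X POST "http://localhost:8000/predict/" \
#        -H "Content-Type: application/json" \
#        -d '{"query": "deep learning for protein folding", "top_k": 5}'
# The response is {"results": [...]} with rank, title, authors, and abstract for each paper.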