import os
import time
import threading
from functools import lru_cache

import faiss
import numpy as np
import requests
import gradio as gr
from flask import Flask, request, jsonify
from datasets import load_dataset
from FlagEmbedding import FlagModel
# Global components (lazily initialized on first use)
app = Flask(__name__)
model = None
index = None
corpus = None
def initialize_components():
    global model, index, corpus

    # Load the embedding model with safety checks
    if model is None:
        model = FlagModel(
            "BAAI/bge-large-en-v1.5",
            query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
            use_fp16=True
        )

    # Load corpus from Hugging Face dataset
    if corpus is None:
        dataset = load_dataset("awinml/medrag_corpus_sampled", split="train")
        corpus = [f"{row['id']}\t{row['contents']}" for row in dataset]

    # Build an in-memory FAISS index over the passage embeddings.
    # FlagModel normalizes bge embeddings by default, so inner-product
    # search (IndexFlatIP) behaves like cosine similarity.
    if index is None:
        embeddings = model.encode([doc.split('\t', 1)[1] for doc in corpus])
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(embeddings.astype('float32'))
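# Not part of the original Space, just a possible optimization: the built index
# could be persisted with faiss.write_index(index, "medrag.index") and reloaded
# at startup via faiss.read_index("medrag.index") to avoid re-embedding the
# whole corpus on every cold start (file name here is illustrative).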
@app.route("/retrieve", methods=["POST"])
def retrieve():
    start_time = time.time()

    # Validate request
    data = request.get_json(silent=True)
    if not data or "queries" not in data:
        return jsonify({"error": "Missing 'queries' parameter"}), 400

    # Initialize components if needed
    initialize_components()

    # Process queries
    queries = data["queries"]
    topk = int(data.get("topk", 3))
    return_scores = data.get("return_scores", False)

    # Batch processing: embed all queries at once, then search the index
    query_embeddings = model.encode_queries(queries)
    scores, indices = index.search(query_embeddings.astype('float32'), topk)

    # Format results
    results = []
    for i, query in enumerate(queries):
        query_results = []
        for j in range(topk):
            doc_idx = indices[i][j]
            doc = corpus[doc_idx]
            doc_id, content = doc.split('\t', 1)
            result = {
                "document": {
                    "id": doc_id,
                    "contents": content
                }
            }
            if return_scores:
                result["score"] = float(scores[i][j])
            query_results.append(result)
        results.append(query_results)

    return jsonify({
        "result": results,
        "time": f"{time.time() - start_time:.2f}s"
    })
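# Illustrative request/response for the endpoint above (values are made up):
#   POST /retrieve
#   {"queries": ["first-line treatment for hypertension"], "topk": 2, "return_scores": true}
#   -> {"result": [[{"document": {"id": "...", "contents": "..."}, "score": 0.83}, ...]],
#       "time": "0.12s"}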
# Gradio UI for testing: calls the Flask API over HTTP
def gradio_interface(query, topk):
    response = requests.post(
        "http://localhost:5000/retrieve",  # Flask API port (see __main__ below); Gradio keeps 7860
        json={"queries": [query], "topk": int(topk), "return_scores": True}
    )
    return response.json()["result"][0]
# Start servers
if __name__ == "__main__":
    # First-time initialization (model, corpus, FAISS index)
    initialize_components()

    # Run the Flask API in a background thread on its own port (5000, Flask's
    # default), leaving 7860 free for the Gradio UI.
    threading.Thread(
        target=lambda: app.run(host="0.0.0.0", port=5000),
        daemon=True
    ).start()

    # Create Gradio interface
    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[
            gr.Textbox(label="Medical Query", placeholder="Enter your medical question..."),
            gr.Slider(1, 10, value=3, step=1, label="Top Results")
        ],
        outputs=gr.JSON(label="Retrieval Results"),
        title="Medical Retrieval System",
        description="Search across medical literature using AI-powered semantic search"
    )

    # Launch the Gradio UI (the Flask API is already running in the background)
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
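# Illustrative call to the Flask API once both servers are up (assumes the
# port-5000 setup above; adjust host/port for your deployment):
#   curl -X POST http://localhost:5000/retrieve \
#        -H "Content-Type: application/json" \
#        -d '{"queries": ["first-line treatment for hypertension"], "topk": 3, "return_scores": true}'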