# File size: 3,298 Bytes
# d56a4e8
import faiss
import numpy as np
from FlagEmbedding import FlagModel
from flask import Flask, request, jsonify
from datasets import load_dataset
import gradio as gr
import os
import time
from functools import lru_cache

# Flask application that exposes the /retrieve endpoint.
app = Flask(__name__)

# Shared singletons, all unset at import time; initialize_components()
# fills in whichever of them is still None.
model = index = corpus = None

def initialize_components():
    """Populate the module-level ``model``, ``corpus`` and ``index`` on demand.

    Each global is built only while it is still ``None``, so calling this
    function again after a successful run does nothing. The steps run in a
    fixed order because the index is built from the corpus using the model.
    """
    global model, index, corpus

    # 1) Embedding model.
    if model is None:
        model = FlagModel(
            "BAAI/bge-large-en-v1.5",
            query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
            use_fp16=True,
        )

    # 2) Corpus: one "<id>\t<contents>" string per dataset row.
    if corpus is None:
        rows = load_dataset("awinml/medrag_corpus_sampled", split='train')
        entries = []
        for row in rows:
            entries.append(f"{row['id']}\t{row['contents']}")
        corpus = entries

    # 3) In-memory inner-product index over the passage texts.
    if index is None:
        passages = []
        for entry in corpus:
            # Everything after the first tab is the passage text.
            passages.append(entry.split('\t', 1)[1])
        vectors = model.encode(passages)
        width = vectors.shape[1]
        index = faiss.IndexFlatIP(width)
        index.add(vectors.astype('float32'))

@app.route("/retrieve", methods=["POST"])
def retrieve():
    """Handle ``POST /retrieve``: return the top passages for each query.

    Expected JSON body:
        queries: a list of query strings (a single string is treated as a
            one-element list).
        topk: optional number of passages per query (default 3); must be
            convertible to an integer >= 1.
        return_scores: accepted for compatibility but ignored; a score is
            always included with every hit.

    Returns:
        On success, JSON ``{"result": [[hit, ...], ...], "time": "<s>s"}``
        with one list of hits per query, each hit shaped as
        ``{"document": {"id", "contents"}, "score"}``.
        On invalid input, JSON ``{"error": ...}`` with HTTP status 400.
    """
    start_time = time.time()

    # silent=True turns an unparsable / non-JSON body into None so the
    # caller gets the JSON error below instead of a framework error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "queries" not in data:
        return jsonify({"error": "Missing 'queries' parameter"}), 400

    queries = data["queries"]
    if isinstance(queries, str):
        # A bare string would otherwise be iterated character by character.
        queries = [queries]
    if not isinstance(queries, list) or not all(isinstance(q, str) for q in queries):
        return jsonify({"error": "'queries' must be a string or a list of strings"}), 400

    try:
        topk = int(data.get("topk", 3))
    except (TypeError, ValueError):
        return jsonify({"error": "'topk' must be an integer"}), 400
    if topk < 1:
        return jsonify({"error": "'topk' must be at least 1"}), 400

    # Nothing to search for: answer without touching the model or index.
    if not queries:
        return jsonify({
            "result": [],
            "time": f"{time.time() - start_time:.2f}s"
        })

    # Validation is done first so bad requests never trigger the costly load.
    initialize_components()

    # Batch processing: one encode call and one search call for all queries.
    query_embeddings = model.encode_queries(queries)
    scores, indices = index.search(query_embeddings.astype('float32'), topk)

    results = []
    for score_row, index_row in zip(scores, indices):
        query_results = []
        for score, doc_idx in zip(score_row, index_row):
            # FAISS pads with -1 when fewer than topk vectors exist; without
            # this guard corpus[-1] would return the last document as a hit.
            if doc_idx < 0:
                continue
            doc_id, content = corpus[doc_idx].split('\t', 1)
            query_results.append({
                "document": {
                    "id": doc_id,
                    "contents": content
                },
                "score": float(score)
            })
        results.append(query_results)

    return jsonify({
        "result": results,
        "time": f"{time.time() - start_time:.2f}s"
    })

# Gradio UI for testing
def gradio_interface(query, topk):
    """Run one query through the ``/retrieve`` route and return its hits.

    The route is invoked in-process through Flask's test client, so no HTTP
    server has to be listening and no HTTP client library is required.

    Args:
        query: Query text entered in the UI.
        topk: Number of passages requested; cast to ``int`` because it may
            arrive as a float from the slider widget.

    Returns:
        The list of hits for the single query, or a dict with an ``"error"``
        key when the route does not answer with a successful JSON payload.
    """
    client = app.test_client()
    response = client.post(
        "/retrieve",
        json={"queries": [query], "topk": int(topk)}
    )
    payload = response.get_json(silent=True)
    if response.status_code != 200 or not isinstance(payload, dict):
        # Show the failure in the JSON output widget instead of raising.
        if isinstance(payload, dict) and "error" in payload:
            return payload
        return {"error": f"Retrieval failed with status {response.status_code}"}
    return payload["result"][0]

# Start server
if __name__ == "__main__":
    # Build the model, corpus and index up front rather than on the first request.
    initialize_components()

    # Widgets for the test UI.
    query_box = gr.Textbox(label="Medical Query", placeholder="Enter your medical question...")
    topk_slider = gr.Slider(1, 10, value=3, label="Top Results")
    results_view = gr.JSON(label="Retrieval Results")

    iface = gr.Interface(
        fn=gradio_interface,
        inputs=[query_box, topk_slider],
        outputs=results_view,
        title="Medical Retrieval System",
        description="Search across medical literature using AI-powered semantic search",
    )

    # Only the Gradio app is launched here; the Flask app's own server is
    # not started by this block.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)