# Set up a Python virtual environment and activate it:
# python3 -m venv myenv
# source myenv/bin/activate

# Run the FastAPI application with live reload for development:
# uvicorn main:app --reload

# Run the FastAPI application in the background for production:
# nohup uvicorn main:app --host 0.0.0.0 --port 8000 &

# Check which process is using a specific port (8000 in this case):
# lsof -i :8000

# Terminate the process holding the port by its PID (e.g., PID 2540):
# kill -9 2540

# Example POST request (from Postman or any other HTTP client):
# The source URL is passed as a query parameter; this request indexes an
# openFDA drug-label archive:
# http://your_server_url/index_fda_drugs?url=https://download.open.fda.gov/drug/label/drug-label-0001-of-0012.json.zip

# Testing the endpoint locally with curl:
# curl -X POST "http://127.0.0.1:8000/index_fda_drugs?url=https://download.open.fda.gov/drug/label/drug-label-0001-of-0012.json.zip"
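
# After a run completes, you can sanity-check the stored point count against
# Qdrant's standard REST API (substitute your own cluster URL and API key):
# curl -H "api-key: $QDRANT_API_KEY" "$QDRANT_CLUSTER_URL/collections/fda_drugs"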

import time
from qdrant_client import AsyncQdrantClient
from qdrant_client.http import models
import pandas as pd
import zipfile
import io
import requests
import json
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DataFrameLoader
import uuid
import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Query

app = FastAPI()

# Load environment variables from a .env file
load_dotenv()

# Set up Qdrant client and embedding model
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
QDRANT_CLUSTER_URL = os.environ.get("QDRANT_CLUSTER_URL")
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
client = AsyncQdrantClient(QDRANT_CLUSTER_URL, api_key=QDRANT_API_KEY)
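# Note: text-embedding-3-small returns 1536-dimensional vectors, which must
# match the vector size configured for the collection below.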

async def create_collection():
    # get_collection raises if the collection is missing, so the except branch
    # doubles as the creation path.
    try:
        await client.get_collection(collection_name="fda_drugs")
        print("Collection 'fda_drugs' already exists.")
    except Exception:
        print("Collection 'fda_drugs' does not exist. Creating...")
        collection_info = await client.create_collection(
            collection_name="fda_drugs",
            vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE)
        )
        print(f"Collection 'fda_drugs' created: {collection_info}")

async def index_batch(batch_docs, metadata_fields):
    points = []
    for doc in batch_docs:
        try:
            # Use the async embedding call so the event loop is not blocked
            # while waiting on the embedding API.
            vector = await embedding_model.aembed_query(doc.page_content)
            if vector is not None:
                payload = {field: doc.metadata.get(field, '') for field in metadata_fields}
                payload["page_content"] = doc.page_content
                points.append(models.PointStruct(
                    id=str(uuid.uuid4()),
                    payload=payload,
                    vector=vector,
                ))
        except Exception as e:
            print(f"Failed to embed document: {e}")
    
    if points:
        try:
            await client.upsert(
                collection_name="fda_drugs",
                points=points,
            )
            return len(points)
        except Exception as e:
            print(f"Failed to upsert batch: {e}")
    
    return 0
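
# For larger loads it is usually faster to embed a whole batch in one request;
# a minimal alternative sketch using LangChain's batch embedding API:
#
#     vectors = await embedding_model.aembed_documents(
#         [doc.page_content for doc in batch_docs]
#     )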

@app.post("/index_fda_drugs")
async def index_fda_drugs(url: str = Query(..., description="URL of the ZIP file to index")):
    try:
        start_time = time.time()  # Start timing

        # Create the collection if it does not already exist
        await create_collection()
        
        # Download the ZIP archive and load the JSON payload inside it.
        # (requests.get is synchronous and blocks the event loop; acceptable
        # for a one-off admin endpoint like this.)
        response = requests.get(url)
        response.raise_for_status()  # Fail fast on HTTP errors
        zip_file = zipfile.ZipFile(io.BytesIO(response.content))
        json_file = zip_file.open(zip_file.namelist()[0])
        data = json.load(json_file)
        df = pd.json_normalize(data['results'])
        selected_drugs = df
        
        # Define metadata fields to include
        metadata_fields = ['openfda.brand_name', 'openfda.generic_name', 'openfda.manufacturer_name', 'openfda.product_type',
                        'openfda.route', 'openfda.substance_name', 'openfda.rxcui', 'openfda.spl_id', 'openfda.package_ndc']
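        # (pd.json_normalize flattened the nested "openfda" object into these
        # dotted column names, e.g. "openfda.brand_name".)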
        
        # Fill NaN values with empty strings
        selected_drugs[metadata_fields] = selected_drugs[metadata_fields].fillna('')
        
        # Define text fields to index
        text_fields = ['description', 'indications_and_usage', 'contraindications', 'warnings', 'adverse_reactions', 'dosage_and_administration']
        
        # Fill NaN values with empty strings and concatenate text fields
        selected_drugs[text_fields] = selected_drugs[text_fields].fillna('')
        selected_drugs['page_content'] = selected_drugs[text_fields].apply(lambda x: ' '.join(x.astype(str)), axis=1)
        
        # Create document loader and load drug documents
        loader = DataFrameLoader(selected_drugs, page_content_column='page_content')
        drug_docs = loader.load()
        
        # Update metadata for each document
        for doc, row in zip(drug_docs, selected_drugs.to_dict(orient='records')):
            metadata = {}
            for field in metadata_fields:
                value = row.get(field)
                if isinstance(value, list):
                    value = ', '.join(str(v) for v in value if pd.notna(v))
                elif pd.isna(value):
                    value = 'Not Available'
                metadata[field] = value
            doc.metadata = metadata
        
        # Split drug documents into chunks
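        # Chunk sizes are in characters (the splitter's default length function
        # is len); 1000/100 is a common starting point rather than a tuned value.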
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        split_drug_docs = text_splitter.split_documents(drug_docs)
        total_docs = len(split_drug_docs)  # Get the total number of split documents
        
        # Index documents in batches
        batch_size = 100
        indexed_count = 0
        for i in range(0, total_docs, batch_size):
            batch_docs = split_drug_docs[i:i+batch_size]
            batch_count = await index_batch(batch_docs, metadata_fields)
            indexed_count += batch_count
            print(f"Indexed {indexed_count} / {total_docs} documents")
        
        remaining = total_docs - indexed_count
        print(f"Indexing completed. Indexed {indexed_count} / {total_docs}, Remaining: {remaining}")
        
        end_time = time.time()  # End timing
        total_time = end_time - start_time
        print(f"Total time taken to index: {total_time:.2f} seconds")
        
        return {"message": "Indexing completed"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
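
# Optional convenience entry point so the app can also be launched with
# `python main.py` (assumes this file is saved as main.py, matching the
# uvicorn commands at the top).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)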