# Set up a Python virtual environment and activate it:
# python3 -m venv myenv
# source myenv/bin/activate
# Run the FastAPI application with live reload for development:
# uvicorn main:app --reload
# Run the FastAPI application in the background for production:
# nohup uvicorn main:app --host 0.0.0.0 --port 8000 &
# Check which process is using a specific port (8000 in this case):
# lsof -i :8000
# Terminate a process using a specific port by its PID (e.g., PID 2540):
# kill -9 2540
# Example of a POST request from Postman or any other HTTP client:
# This request indexes data from a specified URL:
# You can call this endpoint by sending a POST request to:
# http://your_server_url/index_fda_drugs?url=https://download.open.fda.gov/drug/label/drug-label-0001-of-0012.json.zip
# where the URL is passed as a query parameter.
# Examples for testing the endpoint locally using curl:
# Local URL testing with curl:
# curl -X POST "http://127.0.0.1:8000/index_fda_drugs?url=https://download.open.fda.gov/drug/label/drug-label-0001-of-0012.json.zip"
import asyncio
import time
from qdrant_client import AsyncQdrantClient
from qdrant_client.http import models
import pandas as pd
import zipfile
import io
import requests
import json
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
from langchain_community.document_loaders import DataFrameLoader
import uuid
import os
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Query
# FastAPI application instance; route handlers below are registered on it.
app = FastAPI()
# Load environment variables from a .env file
load_dotenv()
# Set up Qdrant client and embedding model
# NOTE(review): os.environ.get returns None when the variable is unset, so a
# missing .env surfaces later as a client/auth failure rather than here.
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
QDRANT_CLUSTER_URL = os.environ.get("QDRANT_CLUSTER_URL")
# text-embedding-3-small produces 1536-dim vectors, matching the
# VectorParams(size=1536) used when the collection is created below.
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
client = AsyncQdrantClient(QDRANT_CLUSTER_URL, api_key=QDRANT_API_KEY)
async def create_collection():
    """Ensure the 'fda_drugs' Qdrant collection exists, creating it if needed.

    First tries to fetch the collection; if that raises (most commonly because
    the collection does not exist yet), creates it with a 1536-dimension
    COSINE vector configuration matching the text-embedding-3-small model
    used elsewhere in this module.
    """
    try:
        await client.get_collection(collection_name="fda_drugs")
        print("Collection 'fda_drugs' already exists.")
    except Exception:
        # get_collection raises on a missing collection (and on other client
        # errors); in either case we attempt creation and let any real
        # connectivity problem surface from create_collection itself.
        print("Collection 'fda_drugs' does not exist. Creating...")
        collection_info = await client.create_collection(
            collection_name="fda_drugs",
            vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE),
        )
        print(f"Collection 'fda_drugs' created: {collection_info}")
async def index_batch(batch_docs, metadata_fields):
    """Embed a batch of document chunks and upsert them into 'fda_drugs'.

    Parameters:
        batch_docs: list of Document chunks exposing .page_content and .metadata.
        metadata_fields: metadata keys to copy into each point's payload
            (missing keys default to '').

    Returns:
        int: number of points successfully upserted; 0 if embedding or the
        upsert failed.
    """
    if not batch_docs:
        return 0
    try:
        # Embed all texts in one API call instead of one blocking
        # embed_query call per document — far fewer round trips.
        vectors = embedding_model.embed_documents(
            [doc.page_content for doc in batch_docs]
        )
    except Exception as e:
        print(f"Failed to embed batch: {e}")
        return 0
    points = []
    for doc, vector in zip(batch_docs, vectors):
        if vector is None:
            continue  # skip documents the embedder could not vectorize
        payload = {field: doc.metadata.get(field, '') for field in metadata_fields}
        payload["page_content"] = doc.page_content
        points.append(models.PointStruct(
            id=str(uuid.uuid4()),  # random point id; re-runs create new points
            payload=payload,
            vector=vector,
        ))
    if not points:
        return 0
    try:
        await client.upsert(
            collection_name="fda_drugs",
            points=points,
        )
        return len(points)
    except Exception as e:
        print(f"Failed to upsert batch: {e}")
        return 0
@app.post("/index_fda_drugs")
async def index_fda_drugs(url: str = Query(..., description="URL of the ZIP file to index")):
    """Download a zipped openFDA drug-label JSON export, chunk it, and index it into Qdrant.

    The ZIP at *url* must contain a single JSON file with a top-level
    'results' array (openFDA drug-label export format). Documents are built
    from the concatenated clinical text fields, split into 1000-character
    chunks, and upserted in batches of 100.

    Returns:
        dict: {"message": "Indexing completed"} on success.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    try:
        start_time = time.time()  # Start timing
        # Create the collection if it does not already exist
        await create_collection()
        # Download the archive; fail fast on HTTP errors instead of handing
        # an error page to zipfile, and bound the wait with a timeout.
        response = requests.get(url, timeout=300)
        response.raise_for_status()
        # Context managers ensure the archive and inner file are closed.
        with zipfile.ZipFile(io.BytesIO(response.content)) as zip_file:
            with zip_file.open(zip_file.namelist()[0]) as json_file:
                data = json.load(json_file)
        df = pd.json_normalize(data['results'])
        selected_drugs = df
        # Define metadata fields to include in each point's payload
        metadata_fields = ['openfda.brand_name', 'openfda.generic_name', 'openfda.manufacturer_name', 'openfda.product_type',
                           'openfda.route', 'openfda.substance_name', 'openfda.rxcui', 'openfda.spl_id', 'openfda.package_ndc']
        # Fill NaN values with empty strings
        selected_drugs[metadata_fields] = selected_drugs[metadata_fields].fillna('')
        # Define text fields whose concatenation becomes the page content
        text_fields = ['description', 'indications_and_usage', 'contraindications', 'warnings', 'adverse_reactions', 'dosage_and_administration']
        # Fill NaN values with empty strings and concatenate text fields
        selected_drugs[text_fields] = selected_drugs[text_fields].fillna('')
        selected_drugs['page_content'] = selected_drugs[text_fields].apply(lambda x: ' '.join(x.astype(str)), axis=1)
        # Create document loader and load drug documents
        loader = DataFrameLoader(selected_drugs, page_content_column='page_content')
        drug_docs = loader.load()
        # Rebuild each document's metadata: lists become comma-joined
        # strings, NaN becomes a 'Not Available' marker.
        for doc, row in zip(drug_docs, selected_drugs.to_dict(orient='records')):
            metadata = {}
            for field in metadata_fields:
                value = row.get(field)
                if isinstance(value, list):
                    value = ', '.join(str(v) for v in value if pd.notna(v))
                elif pd.isna(value):
                    value = 'Not Available'
                metadata[field] = value
            doc.metadata = metadata
        # Split drug documents into overlapping chunks for embedding
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        split_drug_docs = text_splitter.split_documents(drug_docs)
        total_docs = len(split_drug_docs)
        # Index documents in batches to bound memory and request size
        batch_size = 100
        indexed_count = 0
        for i in range(0, total_docs, batch_size):
            batch_docs = split_drug_docs[i:i + batch_size]
            indexed_count += await index_batch(batch_docs, metadata_fields)
            print(f"Indexed {indexed_count} / {total_docs} documents")
        remaining = total_docs - indexed_count
        print(f"Indexing completed. Indexed {indexed_count} / {total_docs}, Remaining: {remaining}")
        total_time = time.time() - start_time
        print(f"Total time taken to index: {total_time:.2f} seconds")
        return {"message": "Indexing completed"}
    except Exception as e:
        # Surface any failure (download, parse, embed, upsert) as a 500
        raise HTTPException(status_code=500, detail=str(e))