import os
import re
import zipfile

import faiss
import numpy as np
import pandas as pd
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer

# 🌍 Set Hugging Face cache directory (optional)
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"

# ✅ Unzip the model into /tmp to avoid permission issues
model_path = "/tmp/my_model"
if not os.path.exists(model_path):
    with zipfile.ZipFile("my_model.zip", "r") as zip_ref:
        zip_ref.extractall(model_path)
    print("✅ Model unzipped!")

# 🤗 Load tokenizer and model from the local directory
retrieval_tokenizer = AutoTokenizer.from_pretrained(model_path)
retrieval_model = AutoModel.from_pretrained(model_path)

# ✅ Start FastAPI app
app = FastAPI()

# 📄 Load clinical trials CSV
csv_path = "ctg-studies-obesity.csv"
if os.path.exists(csv_path):
    df_trials = pd.read_csv(csv_path)
    print("✅ CSV File Loaded Successfully!")
else:
    raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}")

# 🏷️ Rename columns for consistency ("Completion Date" and "Sponsor" already
# match, so only the differing names need a mapping)
df_trials.rename(columns={
    "NCT Number": "NCTID",
    "Interventions": "Intervention",
    "Phases": "Phase",
    "Study Status": "Status",
    "Study Results": "Has Results",
}, inplace=True)

# 📦 Load FAISS index (or fall back to an empty L2 index)
dimension = 768
faiss_index_path = "clinical_trials.index"
if os.path.exists(faiss_index_path):
    index = faiss.read_index(faiss_index_path)
    print("✅ FAISS Index Loaded!")
else:
    index = faiss.IndexFlatL2(dimension)
    print("⚠ FAISS Index Not Found. Using Empty Index.")

# 📦 Request models
class QueryRequest(BaseModel):
    text: str

class StudyText(BaseModel):
    text: str

# 🧠 Generate an embedding for a piece of text
def generate_embedding(text):
    inputs = retrieval_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    with torch.no_grad():
        outputs = retrieval_model(**inputs)
    # Use the [CLS] token as the sentence vector; FAISS expects a
    # C-contiguous float32 array, so convert explicitly
    return np.ascontiguousarray(outputs.last_hidden_state[:, 0, :].numpy(), dtype=np.float32)

# 📑 Get trial details by NCT ID
def get_trial_info(nct_id):
    trial_info = df_trials[df_trials["NCTID"] == nct_id].fillna("N/A").to_dict(orient="records")
    return trial_info[0] if trial_info else None
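# ✳️ Optional sanity check (a sketch, assuming the local model is BERT-sized
# with hidden size 768, matching `dimension` above). Running this once at
# startup fails fast if the embedding width and the FAISS index width disagree:
#
#     vec = generate_embedding("test query")   # expected shape: (1, 768)
#     assert vec.shape[1] == index.d, "embedding size != FAISS index dimension"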
# ✂️ Extract a short extractive summary via TF-IDF sentence scoring
def extract_summary(text, max_sentences=2):
    if not isinstance(text, str) or not text.strip():
        return "No summary available."
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) <= max_sentences:
        return text
    vectorizer = TfidfVectorizer(stop_words="english")
    tfidf_matrix = vectorizer.fit_transform(sentences)
    similarity_matrix = cosine_similarity(tfidf_matrix, tfidf_matrix)
    scores = similarity_matrix.sum(axis=1)
    ranked_sentences = sorted(zip(scores, sentences), reverse=True)[:max_sentences]
    # Re-sort the top-ranked sentences into their original order for readability
    return " ".join(s[1] for s in sorted(ranked_sentences, key=lambda x: sentences.index(x[1])))

# 🕒 Extract screening/treatment/follow-up durations (in weeks) from free text
def extract_study_timeline(text: str):
    def extract(patterns):
        for pattern in patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                return int(match.group(1))
        return None

    return {
        "Screening": extract([r'Screening.*?(\d+)\s*weeks?']),
        "Treatment": extract([r'Treatment.*?(\d+)\s*weeks?']),
        "Follow-Up": extract([r'Follow[-\s]*up.*?(\d+)\s*weeks?']),
    }

# 🚀 API routes
@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
    query_vector = generate_embedding(request.text)
    total_trials = index.ntotal
    if total_trials == 0:
        return {"matched_trials": []}
    distances, indices = index.search(query_vector, total_trials)

    matched_trials = []
    for idx in indices[0]:
        # FAISS pads missing results with -1, so guard both bounds
        if 0 <= idx < len(df_trials):
            nct_id = df_trials.iloc[idx]["NCTID"]
            trial_data = get_trial_info(nct_id)
            if trial_data:
                summary_text = (
                    trial_data.get("Brief Summary")
                    or trial_data.get("Description")
                    or trial_data.get("Detailed Description", "")
                )
                matched_trials.append({
                    "NCTID": trial_data["NCTID"],
                    "Intervention": trial_data.get("Intervention", "N/A"),
                    "Phase": trial_data.get("Phase", "N/A"),
                    "Status": trial_data.get("Status", "N/A"),
                    "Completion Date": trial_data.get("Completion Date", "N/A"),
                    "Has Results": trial_data.get("Has Results", "N/A"),
                    "Sponsor": trial_data.get("Sponsor", "N/A"),
                    "Summary": extract_summary(summary_text),
                })
    return {"matched_trials": matched_trials}

@app.post("/extract-timeline/")
async def extract_timeline(request: StudyText):
    return extract_study_timeline(request.text)

@app.get("/trial/{nct_id}")
async def get_trial_details(nct_id: str):
    trial_data = get_trial_info(nct_id)
    return {"trial_details": trial_data} if trial_data else {"error": "Trial not found"}

@app.get("/")
async def root():
    return {"message": "🌟 TrialGPT API is Running with Local Model & Timeline Extraction!"}
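# ▶️ Usage sketch (assumptions: this file is saved as app.py and uvicorn is
# installed; adjust the module path and the example query to your project):
#
#     uvicorn app:app --host 0.0.0.0 --port 8000
#
#     curl -X POST http://localhost:8000/retrieve \
#          -H "Content-Type: application/json" \
#          -d '{"text": "GLP-1 agonist trial for obesity"}'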