tannu038 commited on
Commit
574ee0a
Β·
verified Β·
1 Parent(s): 1291e4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -4,16 +4,29 @@ import faiss
4
  import torch
5
  import numpy as np
6
  import pandas as pd
 
7
  from fastapi import FastAPI
8
  from pydantic import BaseModel
9
  from transformers import AutoModel, AutoTokenizer
10
  from sklearn.feature_extraction.text import TfidfVectorizer
11
  from sklearn.metrics.pairwise import cosine_similarity
12
 
13
- # 🌍 Set Hugging Face Cache Directory
14
  os.environ["HF_HOME"] = "/app/huggingface"
15
  os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
16
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  app = FastAPI()
18
 
19
  # πŸ“„ Load Clinical Trials CSV
@@ -45,11 +58,6 @@ else:
45
  index = faiss.IndexFlatL2(dimension)
46
  print("⚠ FAISS Index Not Found. Using Empty Index.")
47
 
48
- # πŸ€— Load Public Model from Hugging Face
49
- retrieval_model_name = "distilbert-base-uncased"
50
- retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
51
- retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
52
-
53
  # πŸ“¦ Request Models
54
  class QueryRequest(BaseModel):
55
  text: str
@@ -137,4 +145,4 @@ async def get_trial_details(nct_id: str):
137
 
138
  @app.get("/")
139
  async def root():
140
- return {"message": "🌟 TrialGPT API is Running with Public Model & Timeline Extraction!"}
 
4
  import torch
5
  import numpy as np
6
  import pandas as pd
7
+ import zipfile
8
  from fastapi import FastAPI
9
  from pydantic import BaseModel
10
  from transformers import AutoModel, AutoTokenizer
11
  from sklearn.feature_extraction.text import TfidfVectorizer
12
  from sklearn.metrics.pairwise import cosine_similarity
13
 
14
+ # 🌍 Set Hugging Face Cache Directory (if needed)
15
  os.environ["HF_HOME"] = "/app/huggingface"
16
  os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
17
 
18
+ # βœ… Unzip model if not already unzipped
19
+ model_path = "my_model"
20
+ if not os.path.exists(model_path):
21
+ with zipfile.ZipFile("my_model.zip", "r") as zip_ref:
22
+ zip_ref.extractall(model_path)
23
+ print("βœ… Model unzipped!")
24
+
25
+ # πŸ€— Load tokenizer and model from local directory
26
+ retrieval_tokenizer = AutoTokenizer.from_pretrained(model_path)
27
+ retrieval_model = AutoModel.from_pretrained(model_path)
28
+
29
+ # βœ… Start FastAPI
30
  app = FastAPI()
31
 
32
  # πŸ“„ Load Clinical Trials CSV
 
58
  index = faiss.IndexFlatL2(dimension)
59
  print("⚠ FAISS Index Not Found. Using Empty Index.")
60
 
 
 
 
 
 
61
  # πŸ“¦ Request Models
62
  class QueryRequest(BaseModel):
63
  text: str
 
145
 
146
  @app.get("/")
147
  async def root():
148
+ return {"message": "🌟 TrialGPT API is Running with Local Model & Timeline Extraction!"}