Spaces:
Runtime error
Runtime error
The app
Browse files
app.py
CHANGED
@@ -1,7 +1,203 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import pandas as pd
|
3 |
+
import os
|
4 |
+
from tqdm.auto import tqdm
|
5 |
+
import pinecone
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
import torch
|
8 |
|
9 |
+
from transformers import AutoModel, AutoConfig
|
10 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
11 |
+
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
12 |
|
13 |
+
# !pip install transformers accelerate
|
14 |
+
# !pip install -qU pinecone-client[grpc] sentence-transformers
|
15 |
+
# !pip install gradio
|
16 |
+
|
17 |
+
class PineconeIndex:
|
18 |
+
|
19 |
+
def __init__(self):
|
20 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
+
|
22 |
+
self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
|
23 |
+
self.index_name = 'semantic-search-fast-med'
|
24 |
+
self.index = None
|
25 |
+
|
26 |
+
def init_pinecone(self):
|
27 |
+
|
28 |
+
index_name = self.index_name
|
29 |
+
sentence_model = self.sm
|
30 |
+
|
31 |
+
# get api key from app.pinecone.io
|
32 |
+
PINECONE_API_KEY = "b97d5759-dd39-428b-a1fd-ed30f3ba74ee" # os.environ.get('PINECONE_API_KEY') or 'PINECONE_API_KEY'
|
33 |
+
# find your environment next to the api key in pinecone console
|
34 |
+
PINECONE_ENV = "us-west4-gcp" # os.environ.get('PINECONE_ENV') or 'PINECONE_ENV'
|
35 |
+
|
36 |
+
pinecone.init(
|
37 |
+
api_key=PINECONE_API_KEY,
|
38 |
+
environment=PINECONE_ENV
|
39 |
+
)
|
40 |
+
|
41 |
+
pinecone.delete_index(index_name)
|
42 |
+
|
43 |
+
# only create index if it doesn't exist
|
44 |
+
if index_name not in pinecone.list_indexes():
|
45 |
+
pinecone.create_index(
|
46 |
+
name=index_name,
|
47 |
+
dimension=sentence_model.get_sentence_embedding_dimension(),
|
48 |
+
metric='cosine'
|
49 |
+
)
|
50 |
+
|
51 |
+
# now connect to the index
|
52 |
+
self.index = pinecone.GRPCIndex(index_name)
|
53 |
+
return self.index
|
54 |
+
|
55 |
+
def build_index(self):
|
56 |
+
|
57 |
+
if self.index is None:
|
58 |
+
index = self.init_pinecone()
|
59 |
+
else:
|
60 |
+
index = self.index
|
61 |
+
|
62 |
+
if index.describe_index_stats()['total_vector_count']:
|
63 |
+
"Index already built"
|
64 |
+
return
|
65 |
+
|
66 |
+
sentence_model = self.sm
|
67 |
+
|
68 |
+
x = pd.read_excel('/kaggle/input/drug-p/Diseases_data_W.xlsx')
|
69 |
+
|
70 |
+
question_dict = {'About': 'What is {}?', 'Symptoms': 'What are symptoms of {}?',
|
71 |
+
'Causes': 'What are causes of {}?',
|
72 |
+
'Diagnosis': 'What are diagnosis for {}?', 'Risk Factors': 'What are the risk factors for {}?',
|
73 |
+
'Treatment Options': 'What are the treatment options for {}?',
|
74 |
+
'Prognosis and Complications': 'What are the prognosis and complications?'}
|
75 |
+
context = []
|
76 |
+
disease_list = []
|
77 |
+
|
78 |
+
for i in range(len(x)):
|
79 |
+
disease = x.iloc[i, 0]
|
80 |
+
if disease.strip().lower() in disease_list:
|
81 |
+
continue
|
82 |
+
|
83 |
+
disease_list.append(disease.strip().lower())
|
84 |
+
|
85 |
+
conditions = x.iloc[i, 1:].dropna().index
|
86 |
+
answers = x.iloc[i, 1:].dropna()
|
87 |
+
|
88 |
+
for cond in conditions:
|
89 |
+
context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")
|
90 |
+
|
91 |
+
batch_size = 128
|
92 |
+
for i in tqdm(range(0, len(context), batch_size)):
|
93 |
+
# find end of batch
|
94 |
+
i_end = min(i + batch_size, len(context))
|
95 |
+
# create IDs batch
|
96 |
+
ids = [str(x) for x in range(i, i_end)]
|
97 |
+
# create metadata batch
|
98 |
+
metadatas = [{'text': text} for text in context[i:i_end]]
|
99 |
+
# create embeddings
|
100 |
+
xc = sentence_model.encode(context[i:i_end])
|
101 |
+
# create records list for upsert
|
102 |
+
records = zip(ids, xc, metadatas)
|
103 |
+
# upsert to Pinecone
|
104 |
+
index.upsert(vectors=records)
|
105 |
+
|
106 |
+
# check number of records in the index
|
107 |
+
index.describe_index_stats()
|
108 |
+
|
109 |
+
def search(self, query: str = "medicines for fever"):
|
110 |
+
|
111 |
+
sentence_model = self.sm
|
112 |
+
|
113 |
+
if self.index is None:
|
114 |
+
self.build_index()
|
115 |
+
|
116 |
+
index = self.index
|
117 |
+
|
118 |
+
# create the query vector
|
119 |
+
xq = sentence_model.encode(query).tolist()
|
120 |
+
|
121 |
+
# now query
|
122 |
+
xc = index.query(xq, top_k = 3, include_metadata = True)
|
123 |
+
|
124 |
+
results = []
|
125 |
+
|
126 |
+
for i in xc['matches']:
|
127 |
+
results.append(i['metadata']['text'])
|
128 |
+
|
129 |
+
return results
|
130 |
+
|
131 |
+
|
132 |
+
|
133 |
+
class QAModel():
|
134 |
+
def __init__(self, checkpoint="google/flan-t5-xl"):
|
135 |
+
self.checkpoint = checkpoint
|
136 |
+
self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"
|
137 |
+
|
138 |
+
def store_sharded_model(self):
|
139 |
+
tmpdir = self.tmpdir
|
140 |
+
|
141 |
+
checkpoint = self.checkpoint
|
142 |
+
|
143 |
+
if not os.path.exists(tmpdir):
|
144 |
+
os.mkdir(tmpdir)
|
145 |
+
print(f"Directory created - {tmpdir}")
|
146 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
|
147 |
+
print(f"Model loaded - {checkpoint}")
|
148 |
+
model.save_pretrained(tmpdir, max_shard_size="200MB")
|
149 |
+
|
150 |
+
def load_sharded_model(self):
|
151 |
+
tmpdir = self.tmpdir
|
152 |
+
if not os.path.exists(tmpdir):
|
153 |
+
self.store_sharded_model()
|
154 |
+
|
155 |
+
checkpoint = self.checkpoint
|
156 |
+
|
157 |
+
|
158 |
+
config = AutoConfig.from_pretrained(checkpoint)
|
159 |
+
|
160 |
+
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
161 |
+
with init_empty_weights():
|
162 |
+
model = AutoModelForSeq2SeqLM.from_config(config)
|
163 |
+
# model = AutoModelForSeq2SeqLM.from_pretrained(tmpdir)
|
164 |
+
|
165 |
+
model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
|
166 |
+
return model, tokenizer
|
167 |
+
|
168 |
+
def query_model(self, model, tokenizer, query):
|
169 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
170 |
+
return tokenizer.batch_decode(model.generate(**tokenizer(query, return_tensors='pt').to(device)), skip_special_tokens=True)
|
171 |
+
|
172 |
+
PI = PineconeIndex()
|
173 |
+
PI.build_index()
|
174 |
+
qamodel = QAModel()
|
175 |
+
model, tokenizer = qamodel.load_sharded_model()
|
176 |
+
|
177 |
+
def request_answer(query):
|
178 |
+
search_results = PI.search(query)
|
179 |
+
answers = []
|
180 |
+
for r in search_results['matches']:
|
181 |
+
if r['score'] >= 0.45:
|
182 |
+
context = r['metadata']['text']
|
183 |
+
query_to_model = f"""You are doctor who knows cures to diseases. If you don't know the answer, please refrain from providing answers that are not relevant to the context. Please suggest appropriate remedies based on the context provided.\n\nContext: {context}\n\n\nResponse: """
|
184 |
+
answers.append(qamodel.query_model(model, tokenizer, query_to_model))
|
185 |
+
|
186 |
+
if len(answers) == 0:
|
187 |
+
return {'response': "Not enough information to answer the question"}['response']
|
188 |
+
return {'response': '\n'.join(answers)}['response']
|
189 |
+
|
190 |
+
demo = gr.Interface(
|
191 |
+
fn=request_answer,
|
192 |
+
inputs=[
|
193 |
+
gr.components.Textbox(label="User question"),
|
194 |
+
],
|
195 |
+
outputs=["text"],
|
196 |
+
examples=[["Building a demo with Gradio is so easy!",]],
|
197 |
+
cache_examples=False,
|
198 |
+
title="MedQA assistant",
|
199 |
+
#description="MedQA assistant"
|
200 |
+
)
|
201 |
+
|
202 |
+
if __name__ == "__main__":
|
203 |
+
print(demo.launch())
|