xpsychted committed
Commit
d99f88f
1 Parent(s): 5399745
Files changed (1)
  1. app.py +200 -4
app.py CHANGED
@@ -1,7 +1,203 @@
 import gradio as gr

-def greet(name):
-    return "Hello " + name + "!!"

-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
 import gradio as gr
+import pandas as pd
+import os
+from tqdm.auto import tqdm
+import pinecone
+from sentence_transformers import SentenceTransformer
+import torch
+
+from transformers import AutoConfig
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+
+# !pip install transformers accelerate
+# !pip install -qU pinecone-client[grpc] sentence-transformers
+# !pip install gradio
+
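+# NOTE: PINECONE_API_KEY (and optionally PINECONE_ENV) must be set in the
+# environment (e.g. as Space secrets) before the index can be initialised.
+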
+class PineconeIndex:
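+    """Build and query a Pinecone semantic-search index over disease Q&A passages."""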
+
+    def __init__(self):
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+        self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
+        self.index_name = 'semantic-search-fast-med'
+        self.index = None
+
+    def init_pinecone(self):
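+        """Connect to (and if needed create) the Pinecone index."""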
+        index_name = self.index_name
+        sentence_model = self.sm
+
+        # get the API key from app.pinecone.io; read it from the environment
+        # rather than hard-coding a secret in the source
+        PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
+        # find your environment next to the API key in the Pinecone console
+        PINECONE_ENV = os.environ.get('PINECONE_ENV', 'us-west4-gcp')
+
+        pinecone.init(
+            api_key=PINECONE_API_KEY,
+            environment=PINECONE_ENV
+        )
+
+        # only create the index if it doesn't exist
+        if index_name not in pinecone.list_indexes():
+            pinecone.create_index(
+                name=index_name,
+                dimension=sentence_model.get_sentence_embedding_dimension(),
+                metric='cosine'
+            )
+
+        # now connect to the index
+        self.index = pinecone.GRPCIndex(index_name)
+        return self.index
+
+    def build_index(self):
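+        """Embed the disease Q&A passages and upsert them into the index."""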
+        if self.index is None:
+            index = self.init_pinecone()
+        else:
+            index = self.index
+
+        # skip the rebuild if the index is already populated
+        if index.describe_index_stats()['total_vector_count']:
+            print("Index already built")
+            return
+
+        sentence_model = self.sm
+
+        # disease Q&A spreadsheet (path as used in the original Kaggle environment)
+        x = pd.read_excel('/kaggle/input/drug-p/Diseases_data_W.xlsx')
+
+        question_dict = {
+            'About': 'What is {}?',
+            'Symptoms': 'What are the symptoms of {}?',
+            'Causes': 'What are the causes of {}?',
+            'Diagnosis': 'What are the diagnoses for {}?',
+            'Risk Factors': 'What are the risk factors for {}?',
+            'Treatment Options': 'What are the treatment options for {}?',
+            'Prognosis and Complications': 'What are the prognosis and complications of {}?'
+        }
+        context = []
+        disease_list = []
+
+        for i in range(len(x)):
+            disease = x.iloc[i, 0]
+            # skip duplicate diseases
+            if disease.strip().lower() in disease_list:
+                continue
+
+            disease_list.append(disease.strip().lower())
+
+            conditions = x.iloc[i, 1:].dropna().index
+            answers = x.iloc[i, 1:].dropna()
+
+            # one retrievable passage per (question, answer) pair
+            for cond in conditions:
+                context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")
+
+        batch_size = 128
+        for i in tqdm(range(0, len(context), batch_size)):
+            # find end of batch
+            i_end = min(i + batch_size, len(context))
+            # create IDs batch
+            ids = [str(n) for n in range(i, i_end)]
+            # create metadata batch
+            metadatas = [{'text': text} for text in context[i:i_end]]
+            # create embeddings
+            xc = sentence_model.encode(context[i:i_end]).tolist()
+            # create records list for upsert
+            records = zip(ids, xc, metadatas)
+            # upsert to Pinecone
+            index.upsert(vectors=records)
+
+        # check number of records in the index
+        index.describe_index_stats()
+
+    def search(self, query: str = "medicines for fever"):
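+        """Embed the query and return the top-3 Pinecone matches with scores and metadata."""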
+        sentence_model = self.sm
+
+        if self.index is None:
+            self.build_index()
+
+        index = self.index
+
+        # create the query vector
+        xq = sentence_model.encode(query).tolist()
+
+        # now query
+        xc = index.query(xq, top_k=3, include_metadata=True)
+
+        # return the raw matches so callers can filter on the similarity score
+        return xc['matches']
+
+
+class QAModel:
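+    """Shard, load, and query a seq2seq LM via accelerate's empty-weights loading."""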
+    def __init__(self, checkpoint="google/flan-t5-xl"):
+        self.checkpoint = checkpoint
+        self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"
+
+    def store_sharded_model(self):
+        tmpdir = self.tmpdir
+        checkpoint = self.checkpoint
+
+        if not os.path.exists(tmpdir):
+            os.mkdir(tmpdir)
+            print(f"Directory created - {tmpdir}")
+            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
+            print(f"Model loaded - {checkpoint}")
+            # re-save the checkpoint in small shards so it can be streamed
+            # in later without holding the whole model in memory
+            model.save_pretrained(tmpdir, max_shard_size="200MB")
+
+    def load_sharded_model(self):
+        tmpdir = self.tmpdir
+        if not os.path.exists(tmpdir):
+            self.store_sharded_model()
+
+        checkpoint = self.checkpoint
+
+        config = AutoConfig.from_pretrained(checkpoint)
+        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
+
+        # instantiate the model skeleton without allocating weights, then
+        # load the sharded checkpoint and dispatch layers across devices
+        with init_empty_weights():
+            model = AutoModelForSeq2SeqLM.from_config(config)
+
+        model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
+        return model, tokenizer
+
+    def query_model(self, model, tokenizer, query):
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        inputs = tokenizer(query, return_tensors='pt').to(device)
+        # batch_decode returns a list with one decoded answer per input
+        return tokenizer.batch_decode(model.generate(**inputs), skip_special_tokens=True)
+
+# build the index and load the model once at startup
+PI = PineconeIndex()
+PI.build_index()
+qamodel = QAModel()
+model, tokenizer = qamodel.load_sharded_model()
+
+def request_answer(query):
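+    """Retrieve relevant passages and ask the LM to answer from them."""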
+    search_results = PI.search(query)
+    answers = []
+    for r in search_results:
+        # keep only matches above a minimum similarity score
+        if r['score'] >= 0.45:
+            context = r['metadata']['text']
+            # include the user's question so the model answers it rather than
+            # just summarising the retrieved context
+            query_to_model = f"""You are a doctor who knows cures to diseases. If you don't know the answer, please refrain from providing answers that are not relevant to the context. Please suggest appropriate remedies based on the context provided.\n\nContext: {context}\n\nQuestion: {query}\n\nResponse: """
+            answers.append(qamodel.query_model(model, tokenizer, query_to_model)[0])
+
+    if len(answers) == 0:
+        return "Not enough information to answer the question"
+    return '\n'.join(answers)
+
+demo = gr.Interface(
+    fn=request_answer,
+    inputs=[
+        gr.components.Textbox(label="User question"),
+    ],
+    outputs=["text"],
+    examples=[["What are the treatment options for fever?"]],
+    cache_examples=False,
+    title="MedQA assistant",
+    # description="MedQA assistant"
+)
+
+if __name__ == "__main__":
+    demo.launch()