import gradio as gr
import os
from pinecone_integration import PineconeIndex
from qa_model import QAModel

# !pip install transformers accelerate
# !pip install -qU pinecone-client[grpc] sentence-transformers
# !pip install gradio
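
# PineconeIndex and QAModel are imported from the local modules
# pinecone_integration.py and qa_model.py. The commented-out code below is kept
# as a reference for what those modules are assumed to implement.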

# class PineconeIndex:

#     def __init__(self):
#         device = 'cuda' if torch.cuda.is_available() else 'cpu'

#         self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
#         self.index_name = 'semantic-search-fast-med'
#         self.index = None

#     def init_pinecone(self):

#         index_name = self.index_name
#         sentence_model = self.sm

#         # get api key from app.pinecone.io
#         PINECONE_API_KEY = "b97d5759-dd39-428b-a1fd-ed30f3ba74ee"  # os.environ.get('PINECONE_API_KEY') or 'PINECONE_API_KEY'
#         # find your environment next to the api key in pinecone console
#         PINECONE_ENV = "us-west4-gcp"  # os.environ.get('PINECONE_ENV') or 'PINECONE_ENV'

#         pinecone.init(
#             api_key=PINECONE_API_KEY,
#             environment=PINECONE_ENV
#         )

# #         pinecone.delete_index(index_name)

#         # only create index if it doesn't exist
#         if index_name not in pinecone.list_indexes():
#             pinecone.create_index(
#                 name=index_name,
#                 dimension=sentence_model.get_sentence_embedding_dimension(),
#                 metric='cosine'
#             )

#         # now connect to the index
#         self.index = pinecone.GRPCIndex(index_name)
#         return self.index

#     def build_index(self):

#         if self.index is None:
#             index = self.init_pinecone()
#         else:
#             index = self.index

#         if index.describe_index_stats()['total_vector_count']:
#             "Index already built"
#             return

#         sentence_model = self.sm

#         x = pd.read_excel('/kaggle/input/drug-p/Diseases_data_W.xlsx')

#         question_dict = {'About': 'What is {}?', 'Symptoms': 'What are symptoms of {}?',
#                          'Causes': 'What are causes of {}?',
#                          'Diagnosis': 'What are diagnosis for {}?', 'Risk Factors': 'What are the risk factors for {}?',
#                          'Treatment Options': 'What are the treatment options for {}?',
#                          'Prognosis and Complications': 'What are the prognosis and complications?'}
#         context = []
#         disease_list = []

#         for i in range(len(x)):
#             disease = x.iloc[i, 0]
#             if disease.strip().lower() in disease_list:
#                 continue

#             disease_list.append(disease.strip().lower())

#             conditions = x.iloc[i, 1:].dropna().index
#             answers = x.iloc[i, 1:].dropna()

#             for cond in conditions:
#                 context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")

#         batch_size = 128
#         for i in tqdm(range(0, len(context), batch_size)):
#             # find end of batch
#             i_end = min(i + batch_size, len(context))
#             # create IDs batch
#             ids = [str(x) for x in range(i, i_end)]
#             # create metadata batch
#             metadatas = [{'text': text} for text in context[i:i_end]]
#             # create embeddings
#             xc = sentence_model.encode(context[i:i_end])
#             # create records list for upsert
#             records = zip(ids, xc, metadatas)
#             # upsert to Pinecone
#             index.upsert(vectors=records)

#         # check number of records in the index
#         index.describe_index_stats()

#     def search(self, query: str = "medicines for fever"):

#         sentence_model = self.sm

#         if self.index is None:
#             self.build_index()

#         index = self.index

#         # create the query vector
#         xq = sentence_model.encode(query).tolist()

#         # now query
#         xc = index.query(xq, top_k = 3, include_metadata = True)
        
#         return xc


# class QAModel():
#     def __init__(self, checkpoint="google/flan-t5-xl"):
#         self.checkpoint = checkpoint
#         self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"

#     def store_sharded_model(self):
#         tmpdir = self.tmpdir
        
#         checkpoint = self.checkpoint
        
#         if not os.path.exists(tmpdir):
#             os.mkdir(tmpdir)
#             print(f"Directory created - {tmpdir}")
#             model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
#             print(f"Model loaded - {checkpoint}")
#             model.save_pretrained(tmpdir, max_shard_size="200MB")

#     def load_sharded_model(self):
#         tmpdir = self.tmpdir
#         if not os.path.exists(tmpdir):
#             self.store_sharded_model()
            
#         checkpoint = self.checkpoint
        

#         config = AutoConfig.from_pretrained(checkpoint)

#         tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#         with init_empty_weights():
#             model = AutoModelForSeq2SeqLM.from_config(config)
#             # model = AutoModelForSeq2SeqLM.from_pretrained(tmpdir)

#         model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
#         return model, tokenizer

#     def query_model(self, model, tokenizer, query):
#         device = 'cuda' if torch.cuda.is_available() else 'cpu'
#         return tokenizer.batch_decode(model.generate(**tokenizer(query, return_tensors='pt').to(device)), skip_special_tokens=True)[0]

PI = PineconeIndex()
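# Building the Pinecone index and loading the sharded FLAN-T5 model happen once
# at startup; both can take a while on limited hardware.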
PI.build_index()
qamodel = QAModel()
model, tokenizer = qamodel.load_sharded_model()

def request_answer(query):
    search_results = PI.search(query)
    answers = []
    # print(search_results)
    for r in search_results['matches']:
        if r['score'] >= 0.45:
            tokenized_context = tokenizer(r['metadata']['text'])
#             query_to_model = f"""You are doctor who knows cures to diseases. If you don't know the answer, please refrain from providing answers that are not relevant to the context. Please suggest appropriate remedies based on the context provided.\n\nContext: {context}\n\n\nResponse: """
            query_to_model = """You are doctor who knows cures to diseases. If you don't know, say you don't know. Please respond appropriately based on the context provided.\n\nContext: {}\n\n\nResponse: """
            for ind in range(0, len(tokenized_context['input_ids']), 512-42):                        
                decoded_tokens_for_context = tokenizer.batch_decode([tokenized_context['input_ids'][ind:ind+470]], skip_special_tokens=True)
                response = qamodel.query_model(model, tokenizer, query_to_model.format(decoded_tokens_for_context[0]))
                
                if not "don't know" in response:
                    answers.append(response)

    if len(answers) == 0:
        return 'Not enough information to answer the question'
    return '\n'.join(answers)
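
# Quick sanity check (hypothetical query; requires the index and model loaded above):
# print(request_answer("What are the symptoms of malaria?"))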


demo = gr.Interface(
    fn=request_answer,
    inputs=[
        gr.components.Textbox(label="User question(Response may take up to 2 mins because of hardware limitation)"),
    ],
    outputs=[
        gr.components.Textbox(label="Output (The answer is meant as a reference and not actual advice)"),
    ],
    cache_examples=True,
    title="MedQA assistant",
    #description="MedQA assistant"
)

demo.launch()