import gradio as gr
import pandas as pd
import os
from tqdm.auto import tqdm
import pinecone
from sentence_transformers import SentenceTransformer
import torch

from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

# !pip install transformers accelerate
# !pip install -qU pinecone-client[grpc] sentence-transformers
# !pip install gradio
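
# Pipeline overview: disease Q/A passages are embedded with a sentence
# transformer and stored in Pinecone; at query time the closest passages are
# retrieved and fed, chunk by chunk, to a Flan-T5 model that drafts the answer
# shown in the Gradio UI.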

class PineconeIndex:

    def __init__(self):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # sentence-embedding model used both to build the index and to encode queries
        self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
        self.index_name = 'semantic-search-fast-med'
        self.index = None

    def init_pinecone(self):

        index_name = self.index_name
        sentence_model = self.sm

        # get api key from app.pinecone.io
        PINECONE_API_KEY = "b97d5759-dd39-428b-a1fd-ed30f3ba74ee"  # os.environ.get('PINECONE_API_KEY') or 'PINECONE_API_KEY'
        # find your environment next to the api key in pinecone console
        PINECONE_ENV = "us-west4-gcp"  # os.environ.get('PINECONE_ENV') or 'PINECONE_ENV'

        pinecone.init(
            api_key=PINECONE_API_KEY,
            environment=PINECONE_ENV
        )

#         pinecone.delete_index(index_name)

        # only create index if it doesn't exist
        if index_name not in pinecone.list_indexes():
            pinecone.create_index(
                name=index_name,
                dimension=sentence_model.get_sentence_embedding_dimension(),
                metric='cosine'
            )

        # now connect to the index
        self.index = pinecone.GRPCIndex(index_name)
        return self.index

    def build_index(self):

        if self.index is None:
            index = self.init_pinecone()
        else:
            index = self.index

        # nothing to do if the index already holds vectors
        if index.describe_index_stats()['total_vector_count']:
            print("Index already built")
            return

        sentence_model = self.sm

        x = pd.read_excel('/kaggle/input/drug-p/Diseases_data_W.xlsx')

        question_dict = {'About': 'What is {}?', 'Symptoms': 'What are symptoms of {}?',
                         'Causes': 'What are causes of {}?',
                         'Diagnosis': 'What is the diagnosis for {}?', 'Risk Factors': 'What are the risk factors for {}?',
                         'Treatment Options': 'What are the treatment options for {}?',
                         'Prognosis and Complications': 'What are the prognosis and complications of {}?'}
        context = []
        disease_list = []

        for i in range(len(x)):
            disease = x.iloc[i, 0]
            if disease.strip().lower() in disease_list:
                continue

            disease_list.append(disease.strip().lower())

            conditions = x.iloc[i, 1:].dropna().index
            answers = x.iloc[i, 1:].dropna()

            for cond in conditions:
                context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")
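
        # Each context entry is one question/answer pair rendered from a single
        # spreadsheet cell, roughly (illustrative; actual text comes from the file):
        #   "What are symptoms of malaria?\n\n<contents of the Symptoms cell>"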

        batch_size = 128
        for i in tqdm(range(0, len(context), batch_size)):
            # find end of batch
            i_end = min(i + batch_size, len(context))
            # create IDs batch
            ids = [str(idx) for idx in range(i, i_end)]  # avoid shadowing the dataframe `x`
            # create metadata batch
            metadatas = [{'text': text} for text in context[i:i_end]]
            # create embeddings
            xc = sentence_model.encode(context[i:i_end])
            # create records list for upsert
            records = zip(ids, xc, metadatas)
            # upsert to Pinecone
            index.upsert(vectors=records)

        # check number of records in the index
        index.describe_index_stats()

    def search(self, query: str = "medicines for fever"):

        sentence_model = self.sm

        if self.index is None:
            self.build_index()

        index = self.index

        # create the query vector
        xq = sentence_model.encode(query).tolist()

        # now query
        xc = index.query(xq, top_k=3, include_metadata=True)

        return xc
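
    # Minimal usage sketch of search(); the response shape shown is what
    # request_answer() below relies on (field names per the pinecone GRPC client used here):
    #
    #   pi = PineconeIndex()
    #   res = pi.search("what are the symptoms of malaria?")
    #   # res['matches'] -> list of {'id': ..., 'score': ..., 'metadata': {'text': ...}}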


class QAModel:
    def __init__(self, checkpoint="google/flan-t5-xl"):
        self.checkpoint = checkpoint
        self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"

    def store_sharded_model(self):
        tmpdir = self.tmpdir

        checkpoint = self.checkpoint

        if not os.path.exists(tmpdir):
            os.mkdir(tmpdir)
            print(f"Directory created - {tmpdir}")
            model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
            print(f"Model loaded - {checkpoint}")
            # re-save the checkpoint in small (200MB) shards so it can later be
            # loaded piece by piece on memory-constrained hardware
            model.save_pretrained(tmpdir, max_shard_size="200MB")

    def load_sharded_model(self):
        tmpdir = self.tmpdir
        if not os.path.exists(tmpdir):
            self.store_sharded_model()

        checkpoint = self.checkpoint

        config = AutoConfig.from_pretrained(checkpoint)

        tokenizer = AutoTokenizer.from_pretrained(checkpoint)

        # build the model skeleton without allocating real weights, then stream the
        # sharded checkpoint in and let accelerate place layers across devices
        with init_empty_weights():
            model = AutoModelForSeq2SeqLM.from_config(config)
            # model = AutoModelForSeq2SeqLM.from_pretrained(tmpdir)

        model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
        return model, tokenizer

    def query_model(self, model, tokenizer, query):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # tokenize the prompt, generate, and decode the single returned sequence
        inputs = tokenizer(query, return_tensors='pt').to(device)
        outputs = model.generate(**inputs)
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
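
# The retrieval index and the Flan-T5 generator are set up once at module load,
# so each Gradio request only pays for query encoding and answer generation.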

PI = PineconeIndex()
PI.build_index()
qamodel = QAModel()
model, tokenizer = qamodel.load_sharded_model()

def request_answer(query):
    search_results = PI.search(query)
    answers = []
    # print(search_results)
    for r in search_results['matches']:
        if r['score'] >= 0.45:
            tokenized_context = tokenizer(r['metadata']['text'])
#             query_to_model = f"""You are doctor who knows cures to diseases. If you don't know the answer, please refrain from providing answers that are not relevant to the context. Please suggest appropriate remedies based on the context provided.\n\nContext: {context}\n\n\nResponse: """
            query_to_model = """You are a doctor who knows cures to diseases. If you don't know, say you don't know. Please respond appropriately based on the context provided.\n\nContext: {}\n\n\nResponse: """
            # feed the retrieved passage to the model in ~470-token windows so the
            # prompt plus context stays within the model's input budget
            for ind in range(0, len(tokenized_context['input_ids']), 512 - 42):
                decoded_tokens_for_context = tokenizer.batch_decode([tokenized_context['input_ids'][ind:ind + 470]], skip_special_tokens=True)
                response = qamodel.query_model(model, tokenizer, query_to_model.format(decoded_tokens_for_context[0]))

                if "don't know" not in response:
                    answers.append(response)

    if len(answers) == 0:
        return 'Not enough information to answer the question'
    return '\n'.join(answers)
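
# Example (hypothetical query; answers from all matches scoring >= 0.45 are
# concatenated into one response):
#   request_answer("What are the treatment options for malaria?")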


demo = gr.Interface(
    fn=request_answer,
    inputs=[
        gr.components.Textbox(label="User question (the response may take up to 2 minutes due to hardware limitations)"),
    ],
    outputs=[
        gr.components.Textbox(label="Output (the answer is meant as a reference, not actual medical advice)"),
    ],
    cache_examples=True,
    title="MedQA assistant",
    #description="MedQA assistant"
)

demo.launch()
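
# When running on a notebook platform where localhost is not reachable,
# demo.launch(share=True) can be used instead to expose a temporary public URL.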