# MedQA-Assistant / app.py
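# Retrieval-augmented medical QA demo: user questions are matched against
# disease passages stored in a Pinecone index (MiniLM sentence embeddings),
# and the retrieved context is answered by a sharded FLAN-T5-XL model behind
# a Gradio interface.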
import gradio as gr
import pandas as pd
import os
from tqdm.auto import tqdm
import pinecone
from sentence_transformers import SentenceTransformer
import torch
from transformers import AutoModel, AutoConfig
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
# !pip install transformers accelerate
# !pip install -qU pinecone-client[grpc] sentence-transformers
# !pip install gradio
class PineconeIndex:
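    """Wraps a Pinecone index of MiniLM sentence embeddings built from a
    disease spreadsheet, used for semantic retrieval of medical context."""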
def __init__(self):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
self.index_name = 'semantic-search-fast-med'
self.index = None
def init_pinecone(self):
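        """Connect to Pinecone and create the index (cosine metric, MiniLM
        embedding dimension) if it does not already exist."""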
index_name = self.index_name
sentence_model = self.sm
# get api key from app.pinecone.io
        # get api key from app.pinecone.io; read it from the environment
        # instead of hardcoding the secret in source
        PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
        # find your environment next to the api key in pinecone console
        PINECONE_ENV = os.environ.get('PINECONE_ENV', 'us-west4-gcp')
pinecone.init(
api_key=PINECONE_API_KEY,
environment=PINECONE_ENV
)
# pinecone.delete_index(index_name)
# only create index if it doesn't exist
if index_name not in pinecone.list_indexes():
pinecone.create_index(
name=index_name,
dimension=sentence_model.get_sentence_embedding_dimension(),
metric='cosine'
)
# now connect to the index
self.index = pinecone.GRPCIndex(index_name)
return self.index
def build_index(self):
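        """Embed one question-answer passage per (disease, section) pair from
        the spreadsheet and upsert them to Pinecone in batches of 128."""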
if self.index is None:
index = self.init_pinecone()
else:
index = self.index
        if index.describe_index_stats()['total_vector_count']:
            print("Index already built")
            return
sentence_model = self.sm
        # NOTE: the spreadsheet is expected next to app.py (the original Kaggle
        # path '/kaggle/input/drug-p/Diseases_data_W.xlsx' is not available here)
        x = pd.read_excel('Diseases_data_W.xlsx')
        question_dict = {
            'About': 'What is {}?',
            'Symptoms': 'What are the symptoms of {}?',
            'Causes': 'What are the causes of {}?',
            'Diagnosis': 'What is the diagnosis for {}?',
            'Risk Factors': 'What are the risk factors for {}?',
            'Treatment Options': 'What are the treatment options for {}?',
            'Prognosis and Complications': 'What are the prognosis and complications of {}?',
        }
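        # Expected spreadsheet layout (assumed from the loop below): the first
        # column holds the disease name and the remaining columns match the
        # question_dict keys, e.g.
        #   Disease | About          | Symptoms           | ...
        #   Malaria | Malaria is ... | Fever, chills, ... | ...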
context = []
disease_list = []
for i in range(len(x)):
disease = x.iloc[i, 0]
if disease.strip().lower() in disease_list:
continue
disease_list.append(disease.strip().lower())
conditions = x.iloc[i, 1:].dropna().index
answers = x.iloc[i, 1:].dropna()
for cond in conditions:
context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")
batch_size = 128
for i in tqdm(range(0, len(context), batch_size)):
# find end of batch
i_end = min(i + batch_size, len(context))
# create IDs batch
            ids = [str(j) for j in range(i, i_end)]  # avoid shadowing the DataFrame `x`
# create metadata batch
metadatas = [{'text': text} for text in context[i:i_end]]
# create embeddings
xc = sentence_model.encode(context[i:i_end])
# create records list for upsert
            records = list(zip(ids, xc.tolist(), metadatas))
# upsert to Pinecone
index.upsert(vectors=records)
# check number of records in the index
        print(index.describe_index_stats())
def search(self, query: str = "medicines for fever"):
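        """Embed the query with the same MiniLM model and return the top-3
        nearest passages along with their text metadata."""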
sentence_model = self.sm
if self.index is None:
self.build_index()
index = self.index
# create the query vector
xq = sentence_model.encode(query).tolist()
# now query
        xc = index.query(vector=xq, top_k=3, include_metadata=True)
return xc
class QAModel:
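    """Loads google/flan-t5-xl, re-saved in 200 MB checkpoint shards so
    accelerate can dispatch it across limited device memory."""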
def __init__(self, checkpoint="google/flan-t5-xl"):
self.checkpoint = checkpoint
self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"
def store_sharded_model(self):
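        """Download the full checkpoint once and re-save it in 200 MB shards;
        small shards keep peak memory low when the model is loaded later."""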
tmpdir = self.tmpdir
checkpoint = self.checkpoint
if not os.path.exists(tmpdir):
os.mkdir(tmpdir)
print(f"Directory created - {tmpdir}")
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
print(f"Model loaded - {checkpoint}")
model.save_pretrained(tmpdir, max_shard_size="200MB")
def load_sharded_model(self):
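        """Instantiate the model skeleton with empty (meta) weights, then
        stream the shards in via load_checkpoint_and_dispatch with
        device_map="auto"."""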
tmpdir = self.tmpdir
if not os.path.exists(tmpdir):
self.store_sharded_model()
checkpoint = self.checkpoint
config = AutoConfig.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
with init_empty_weights():
model = AutoModelForSeq2SeqLM.from_config(config)
# model = AutoModelForSeq2SeqLM.from_pretrained(tmpdir)
model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
return model, tokenizer
def query_model(self, model, tokenizer, query):
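        """Generate an answer for a single prompt and decode it to text."""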
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        inputs = tokenizer(query, return_tensors='pt').to(device)
        # max_new_tokens raised above the T5 default so longer answers are not cut off
        return tokenizer.batch_decode(model.generate(**inputs, max_new_tokens=256), skip_special_tokens=True)[0]
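# Build the retrieval index and load the sharded FLAN-T5 model once at startup.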
PI = PineconeIndex()
PI.build_index()
qamodel = QAModel()
model, tokenizer = qamodel.load_sharded_model()
def request_answer(query):
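    """Answer a user question: retrieve the top passages from Pinecone, split
    each into chunks that fit the model's 512-token window, and keep every
    model response that is not an "I don't know" refusal."""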
search_results = PI.search(query)
answers = []
# print(search_results)
for r in search_results['matches']:
if r['score'] >= 0.45:
tokenized_context = tokenizer(r['metadata']['text'])
# query_to_model = f"""You are doctor who knows cures to diseases. If you don't know the answer, please refrain from providing answers that are not relevant to the context. Please suggest appropriate remedies based on the context provided.\n\nContext: {context}\n\n\nResponse: """
            # Reserve ~42 of the model's 512 input tokens for the prompt
            # template, so each context chunk is capped at 470 tokens.
            query_to_model = """You are a doctor who knows cures to diseases. If you don't know, say you don't know. Please respond appropriately based on the context provided.\n\nContext: {}\n\n\nResponse: """
            for ind in range(0, len(tokenized_context['input_ids']), 512 - 42):
                decoded_tokens_for_context = tokenizer.batch_decode([tokenized_context['input_ids'][ind:ind + 470]], skip_special_tokens=True)
                response = qamodel.query_model(model, tokenizer, query_to_model.format(decoded_tokens_for_context[0]))
                if "don't know" not in response:
                    answers.append(response)
if len(answers) == 0:
return 'Not enough information to answer the question'
return '\n'.join(answers)
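# Illustrative call (actual output depends on the indexed spreadsheet):
#   request_answer("What are the symptoms of malaria?")
#   -> the top-scoring passages are answered by FLAN-T5 and joined with newlines.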
demo = gr.Interface(
    fn=request_answer,
    inputs=[
        gr.components.Textbox(label="User question (responses may take up to 2 minutes due to hardware limitations)"),
    ],
    outputs=[
        gr.components.Textbox(label="Output (the answer is meant as a reference, not actual medical advice)"),
    ],
    title="MedQA assistant",
    # description="MedQA assistant"
)
demo.launch()