import gradio as gr
import os
from pinecone_integration import PineconeIndex
from qa_model import QAModel
# !pip install transformers accelerate
# !pip install -qU pinecone-client[grpc] sentence-transformers
# !pip install gradio
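# The two classes used below are imported from the local modules
# pinecone_integration.py and qa_model.py; the commented-out code that
# follows is the reference implementation they are presumed to mirror.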
# Reference implementation of PineconeIndex, kept commented out for
# documentation.
#
# import pandas as pd
# import pinecone
# import torch
# from sentence_transformers import SentenceTransformer
# from tqdm.auto import tqdm
#
# class PineconeIndex:
#     def __init__(self):
#         device = 'cuda' if torch.cuda.is_available() else 'cpu'
#         self.sm = SentenceTransformer('all-MiniLM-L6-v2', device=device)
#         self.index_name = 'semantic-search-fast-med'
#         self.index = None
#
#     def init_pinecone(self):
#         index_name = self.index_name
#         sentence_model = self.sm
#         # Get the API key from app.pinecone.io; read it from the environment
#         # instead of hard-coding a credential in the source.
#         PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
#         # The environment is shown next to the API key in the Pinecone console.
#         PINECONE_ENV = os.environ.get('PINECONE_ENV') or 'us-west4-gcp'
#         pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENV)
#         # Only create the index if it doesn't exist yet.
#         if index_name not in pinecone.list_indexes():
#             pinecone.create_index(
#                 name=index_name,
#                 dimension=sentence_model.get_sentence_embedding_dimension(),
#                 metric='cosine'
#             )
#         # Now connect to the index over gRPC.
#         self.index = pinecone.GRPCIndex(index_name)
#         return self.index
#
#     def build_index(self):
#         index = self.init_pinecone() if self.index is None else self.index
#         if index.describe_index_stats()['total_vector_count']:
#             print("Index already built")
#             return
#         sentence_model = self.sm
#         x = pd.read_excel('/kaggle/input/drug-p/Diseases_data_W.xlsx')
#         question_dict = {
#             'About': 'What is {}?',
#             'Symptoms': 'What are symptoms of {}?',
#             'Causes': 'What are causes of {}?',
#             'Diagnosis': 'What is the diagnosis for {}?',
#             'Risk Factors': 'What are the risk factors for {}?',
#             'Treatment Options': 'What are the treatment options for {}?',
#             'Prognosis and Complications': 'What are the prognosis and complications?'
#         }
#         # Turn each (disease, column) cell into a question/answer passage,
#         # skipping duplicate disease rows.
#         context = []
#         disease_list = []
#         for i in range(len(x)):
#             disease = x.iloc[i, 0]
#             if disease.strip().lower() in disease_list:
#                 continue
#             disease_list.append(disease.strip().lower())
#             conditions = x.iloc[i, 1:].dropna().index
#             answers = x.iloc[i, 1:].dropna()
#             for cond in conditions:
#                 context.append(f"{question_dict[cond].format(disease)}\n\n{answers[cond]}")
#         # Embed and upsert the passages in batches.
#         batch_size = 128
#         for i in tqdm(range(0, len(context), batch_size)):
#             i_end = min(i + batch_size, len(context))       # end of batch
#             ids = [str(j) for j in range(i, i_end)]         # ID batch
#             metadatas = [{'text': text} for text in context[i:i_end]]
#             xc = sentence_model.encode(context[i:i_end])    # embeddings
#             records = zip(ids, xc, metadatas)
#             index.upsert(vectors=records)
#         # Check the number of records in the index.
#         index.describe_index_stats()
#
#     def search(self, query: str = "medicines for fever"):
#         sentence_model = self.sm
#         if self.index is None:
#             self.build_index()
#         index = self.index
#         # Embed the query, then fetch the top matches with their metadata.
#         xq = sentence_model.encode(query).tolist()
#         return index.query(vector=xq, top_k=3, include_metadata=True)
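# A minimal usage sketch for the class above, assuming PINECONE_API_KEY /
# PINECONE_ENV are set and the index has been populated; the query response
# is a dict whose 'matches' list carries a 'score' and the stored 'metadata':
#
#     pi = PineconeIndex()
#     res = pi.search("What are symptoms of malaria?")
#     for m in res['matches']:
#         print(round(m['score'], 3), m['metadata']['text'][:80])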
# Reference implementation of QAModel, kept commented out for documentation.
#
# from accelerate import init_empty_weights, load_checkpoint_and_dispatch
# from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
#
# class QAModel:
#     def __init__(self, checkpoint="google/flan-t5-xl"):
#         self.checkpoint = checkpoint
#         self.tmpdir = f"{self.checkpoint.split('/')[-1]}-sharded"
#
#     def store_sharded_model(self):
#         # Save the checkpoint in small shards so it can be loaded later on
#         # memory-constrained hardware.
#         tmpdir = self.tmpdir
#         checkpoint = self.checkpoint
#         if not os.path.exists(tmpdir):
#             os.mkdir(tmpdir)
#             print(f"Directory created - {tmpdir}")
#             model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
#             print(f"Model loaded - {checkpoint}")
#             model.save_pretrained(tmpdir, max_shard_size="200MB")
#
#     def load_sharded_model(self):
#         tmpdir = self.tmpdir
#         if not os.path.exists(tmpdir):
#             self.store_sharded_model()
#         checkpoint = self.checkpoint
#         config = AutoConfig.from_pretrained(checkpoint)
#         tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#         # Materialize the architecture without allocating weights, then
#         # stream the shards in and place them across available devices.
#         with init_empty_weights():
#             model = AutoModelForSeq2SeqLM.from_config(config)
#         model = load_checkpoint_and_dispatch(model, checkpoint=tmpdir, device_map="auto")
#         return model, tokenizer
#
#     def query_model(self, model, tokenizer, query):
#         device = 'cuda' if torch.cuda.is_available() else 'cpu'
#         inputs = tokenizer(query, return_tensors='pt').to(device)
#         return tokenizer.batch_decode(model.generate(**inputs), skip_special_tokens=True)[0]
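# Usage sketch for the class above (assuming enough disk space for the
# sharded checkpoint; flan-t5-xl is roughly 11 GB of float32 weights):
#
#     qa = QAModel(checkpoint="google/flan-t5-xl")
#     model, tokenizer = qa.load_sharded_model()
#     print(qa.query_model(model, tokenizer, "What is diabetes?"))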
PI = PineconeIndex()
PI.build_index()
qamodel = QAModel()
model, tokenizer = qamodel.load_sharded_model()
def request_answer(query):
    search_results = PI.search(query)
    answers = []
    for r in search_results['matches']:
        # Only use retrieved passages that are reasonably similar to the query.
        if r['score'] >= 0.45:
            tokenized_context = tokenizer(r['metadata']['text'])
            # query_to_model = """You are a doctor who knows cures to diseases. If you don't know the answer, please refrain from providing answers that are not relevant to the context. Please suggest appropriate remedies based on the context provided.\n\nContext: {}\n\n\nResponse: """
            query_to_model = """You are a doctor who knows cures to diseases. If you don't know, say you don't know. Please respond appropriately based on the context provided.\n\nContext: {}\n\n\nResponse: """
            # Slide a 470-token window over the context so each prompt fits the
            # model's 512-token input limit, leaving ~42 tokens of headroom for
            # the instruction text.
            window = 470
            for ind in range(0, len(tokenized_context['input_ids']), window):
                decoded_context = tokenizer.batch_decode(
                    [tokenized_context['input_ids'][ind:ind + window]],
                    skip_special_tokens=True)
                response = qamodel.query_model(
                    model, tokenizer, query_to_model.format(decoded_context[0]))
                if "don't know" not in response:
                    answers.append(response)
    if len(answers) == 0:
        return 'Not enough information to answer the question'
    return '\n'.join(answers)
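# Quick sanity check outside the UI (hypothetical example query):
#     print(request_answer("What are the treatment options for migraine?"))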
demo = gr.Interface(
    fn=request_answer,
    inputs=[
        gr.components.Textbox(label="User question (responses may take up to 2 minutes due to hardware limitations)"),
    ],
    outputs=[
        gr.components.Textbox(label="Output (the answer is meant as a reference, not actual medical advice)"),
    ],
    title="MedQA assistant",
    # description="MedQA assistant"
)

demo.launch()