File size: 2,021 Bytes
1dd74c6
d99f88f
4e44d93
 
1dd74c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6b605e1
 
e154f23
6b605e1
1dd74c6
 
 
063a814
1dd74c6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import gradio as gr
import os
from pinecone_integration import PineconeIndex
from qa_model import QAModel


# Module-level setup, executed once when this file is imported/run: create the
# Pinecone index wrapper and the QA model wrapper, then load the model and its
# tokenizer. ``request_answer`` uses ``PI``, ``qamodel``, ``model`` and
# ``tokenizer`` as globals.
PI = PineconeIndex()
PI.build_index()  # NOTE(review): presumably creates/populates the index used by PI.search -- confirm in pinecone_integration
qamodel = QAModel()
model, tokenizer = qamodel.load_sharded_model()  # returns a (model, tokenizer) pair

# Prompt sent to the QA model for each context chunk; the single ``{}``
# placeholder receives the decoded chunk text.
_QA_PROMPT_TEMPLATE = """You are doctor who knows cures to diseases. If you don't know, say you don't know. Please respond appropriately based on the context provided.\n\nContext: {}\n\n\nResponse: """

# Minimum ``score`` a retrieved match needs before its text is used as context.
_MIN_MATCH_SCORE = 0.45

# Context token ids per prompt (512 - 42 = 470). NOTE(review): this looks like
# a 512-token model budget minus 42 tokens reserved for the prompt template --
# confirm against the model's max input length and the template's token count.
_CONTEXT_CHUNK_TOKENS = 512 - 42

# The prompt asks the model to say it doesn't know; responses containing this
# phrase (case-insensitively) are treated as non-answers.
_UNKNOWN_MARKER = "don't know"


def request_answer(query, *, score_threshold=_MIN_MATCH_SCORE,
                   chunk_tokens=_CONTEXT_CHUNK_TOKENS):
    """Answer ``query`` from Pinecone-retrieved context using the QA model.

    Each search match scoring at least ``score_threshold`` is tokenized and
    split into consecutive windows of ``chunk_tokens`` token ids. Every window
    is decoded back to text, inserted into the prompt template and sent to the
    model. Non-blank responses that do not say "don't know" are collected.

    Args:
        query: The user's question, passed straight to ``PI.search``.
        score_threshold: Minimum match score for a match to be used
            (keyword-only; defaults to the original 0.45).
        chunk_tokens: Number of context token ids per prompt; must be >= 1
            (keyword-only; defaults to the original 470).

    Returns:
        All accepted responses joined with newlines, or
        ``'Not enough information to answer the question'`` if none were
        accepted.

    Raises:
        ValueError: If ``chunk_tokens`` is less than 1.
    """
    if chunk_tokens < 1:
        raise ValueError(f"chunk_tokens must be >= 1, got {chunk_tokens}")

    answers = []
    for match in PI.search(query)['matches']:
        if match['score'] < score_threshold:
            continue
        token_ids = tokenizer(match['metadata']['text'])['input_ids']
        for start in range(0, len(token_ids), chunk_tokens):
            # skip_special_tokens drops special tokens so only the source
            # text is placed in the prompt.
            context = tokenizer.batch_decode(
                [token_ids[start:start + chunk_tokens]],
                skip_special_tokens=True,
            )[0]
            response = qamodel.query_model(
                model, tokenizer, _QA_PROMPT_TEMPLATE.format(context))
            # Normalize case and typographic apostrophes so variants such as
            # "I Don't know" or "I don\u2019t know" are also treated as refusals.
            normalized = response.lower().replace("\u2019", "'")
            # Blank responses are skipped so they cannot suppress the
            # "not enough information" fallback below.
            if response.strip() and _UNKNOWN_MARKER not in normalized:
                answers.append(response)

    if not answers:
        return 'Not enough information to answer the question'
    return '\n'.join(answers)


# Single-textbox Gradio UI: the question text is passed to ``request_answer``
# and its returned string is shown in the output textbox.
demo = gr.Interface(
    fn=request_answer,
    inputs=[
        gr.components.Textbox(
            label="User question (Response may take up to 2 mins because of hardware limitation)",
        ),
    ],
    outputs=[
        gr.components.Textbox(
            label="Output (The answer is meant as a reference and not actual advice)",
        ),
    ],
    title="MedQA assistant",
    description='Check out the repository at: https://github.com/anandshah98/MedQA',
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or tooling) does not block inside ``launch()``.
if __name__ == "__main__":
    demo.launch()