# medical_qa / app.py
import os
import warnings

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

warnings.filterwarnings('ignore')
my_repo = "medical-qa"
username = "pks3kor"  # change this to your Hugging Face username
checkpoint = username + "/" + my_repo

# Load the fine-tuned model and its tokenizer from the Hugging Face Hub
# (AutoModelForCausalLM replaces the deprecated AutoModelWithLMHead)
loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint)
loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
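# Optional sketch (an assumption, not part of the original app): run inference
# on a GPU when one is available. input_ids inside generate_query_response
# would then need to be moved with .to(device) as well.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# loaded_model = loaded_model.to(device)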
# Optional file-based Q&A logging, kept for reference:
# def log_QA(prompt, ans):
#     with open("All_QA.txt", "a+") as fh:
#         fh.write("Question:\n{}\n".format(prompt))
#         fh.write("Answer:\n{}\n".format(ans))
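# Flagged submissions are persisted to a Hugging Face dataset ("medical_qa");
# HF_TOKEN must hold a token with write access, e.g. set as a Space secret.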
HF_TOKEN = os.getenv('HF_TOKEN')
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "medical_qa")
# Generate a model response for a user prompt
def generate_query_response(prompt, max_length=200):
    model = loaded_model
    tokenizer = loaded_tokenizer

    # Encode the prompt as a PyTorch tensor ('pt')
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    # GPT-2 has no dedicated pad token, so build an explicit attention mask
    # and reuse the end-of-sequence token for padding
    attention_mask = torch.ones_like(input_ids)
    pad_token_id = tokenizer.eos_token_id

    # Greedy decoding of a single continuation, capped at max_length tokens
    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
    )

    op = tokenizer.decode(output[0], skip_special_tokens=True)
    # log_QA(prompt, op)
    return op
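# A sampled-decoding variant (a sketch with assumed, untuned values; not part
# of the original app) in case greedy decoding becomes repetitive:
#
# output = model.generate(
#     input_ids,
#     max_length=max_length,
#     do_sample=True,      # sample from the distribution instead of greedy argmax
#     top_p=0.9,           # nucleus sampling cutoff (assumed value)
#     temperature=0.7,     # sharpens/flattens the token distribution (assumed value)
#     attention_mask=attention_mask,
#     pad_token_id=pad_token_id,
# )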
# Gradio UI: a prompt textbox in, a generated answer textbox out; manual
# flagging ("Correct"/"Wrong") is routed to the dataset writer above
iface = gr.Interface(
    fn=generate_query_response,
    inputs="textbox",
    outputs="textbox",
    title="Medical_Q&A",
    description="via gradio",
    allow_flagging="manual",
    flagging_options=["Correct", "Wrong"],
    flagging_callback=hf_writer,
)
# Auto-flagging variant, kept for reference:
# iface = gr.Interface(fn=generate_query_response, inputs="textbox", outputs="textbox",
#                      title="Medical_Q&A", description="via gradio",
#                      allow_flagging="auto", flagging_callback=hf_writer)
iface.launch()
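# Note: on Spaces this serves the app directly; outside Spaces,
# iface.launch(share=True) would also expose a temporary public URL.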