# Captured from a Hugging Face Space page (Space status shown: "Runtime error").
import os
import re
import warnings

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelWithLMHead,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    TextDataset,
    Trainer,
    TrainingArguments,
)
# NOTE(review): this silences *all* warnings, including library deprecation
# notices that would explain a future breakage — consider narrowing it.
warnings.filterwarnings("ignore")

# Hub location of the fine-tuned model, "<username>/<my_repo>".
my_repo = "medical-qa"
username = "pks3kor"  # change it to your HuggingFace username
checkpoint = f"{username}/{my_repo}"

# Load the fine-tuned model and its tokenizer from the Hub.
# AutoModelWithLMHead is deprecated in transformers (its FutureWarning directs
# causal language models to AutoModelForCausalLM). This model is only used for
# left-to-right text generation via .generate(), so load it as a causal LM.
loaded_model = AutoModelForCausalLM.from_pretrained(checkpoint)
loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Saver used as the Gradio flagging callback; it targets the "medical_qa"
# dataset. HF_TOKEN comes from the environment (presumably a Space secret) and
# is None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "medical_qa")
def generate_query_response(prompt, max_length=200, model=None, tokenizer=None):
    """Answer ``prompt`` with the fine-tuned language model.

    Args:
        prompt: Question text, as received from the Gradio textbox.
        max_length: Forwarded to ``model.generate(max_length=...)``. In Hugging
            Face ``generate`` this cap includes the prompt tokens, so a long
            question leaves less room for the answer.
        model: Model to generate with; defaults to the module-level
            ``loaded_model``. Overridable for reuse and testing.
        tokenizer: Tokenizer matching ``model``; defaults to the module-level
            ``loaded_tokenizer``.

    Returns:
        The first generated sequence, decoded with special tokens skipped, or
        an empty string when ``prompt`` is empty/``None`` or encodes to no
        tokens.
    """
    if not prompt:
        return ""
    model = loaded_model if model is None else model
    tokenizer = loaded_tokenizer if tokenizer is None else tokenizer

    # return_tensors="pt" -> a PyTorch tensor with one row for this prompt.
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    if input_ids.shape[-1] == 0:
        # Nothing to condition generation on; don't call generate() at all.
        return ""
    # Keep the inputs on the model's device (matters if the model is moved to GPU).
    input_ids = input_ids.to(model.device)

    # Single, unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)
    # NOTE(review): prompts of max_length tokens or more leave no room to
    # generate; consider max_new_tokens if long questions are expected.
    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        # Reuse EOS as the pad id; presumably the tokenizer (GPT-2 style) has
        # no dedicated pad token — TODO confirm.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Web UI: one textbox in, one textbox out. With manual flagging the user gets
# "Correct" / "Wrong" buttons, and flagged examples are passed to hf_writer.
# An earlier variant used allow_flagging="auto" (no flagging_options) with the
# same hf_writer callback.
iface = gr.Interface(
    fn=generate_query_response,
    inputs="textbox",
    outputs="textbox",
    title="Medical_Q&A",
    description="via gradio",
    allow_flagging="manual",
    flagging_options=["Correct", "Wrong"],
    flagging_callback=hf_writer,
)
iface.launch()