paavansundar's picture
Update app.py
68169bd
raw
history blame
No virus
1.67 kB
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
# Base GPT-2 checkpoint identifier passed to from_pretrained. The model and
# tokenizer are loaded once at import time and reused by queryGPT on every call.
__checkpoint = "gpt2"
__model = GPT2LMHeadModel.from_pretrained(__checkpoint)
__tokenizer = GPT2Tokenizer.from_pretrained(__checkpoint)
# NOTE(review): not referenced anywhere in this file; presumably the hub id of
# a fine-tuned model intended to replace the base checkpoint — confirm.
__model_output_path = "paavansundar/Medical_QNA_GPT2"
# --- data preparation ---
def prepareData(csv_path="paavansundar/medquadqna/MedQuAD.csv"):
    """Load the MedQuAD question/answer CSV and build a causal-LM data collator.

    Both objects are returned so a caller (e.g. a fine-tuning routine) can use
    them; callers that ignore the return value behave exactly as before.

    Args:
        csv_path: path of the CSV file to read. NOTE(review): the default looks
            like a Hugging Face repo id rather than a local file; pd.read_csv
            only finds it if that relative path exists on disk — confirm.

    Returns:
        tuple: ``(df, data_collator)`` — the loaded DataFrame and a
        ``DataCollatorForLanguageModeling`` built with ``mlm=False`` and
        PyTorch tensor output.
    """
    df = pd.read_csv(csv_path)
    # mlm=False selects causal (next-token) language-model labels, which is
    # how GPT-2 is trained, rather than masked-LM labels.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=__tokenizer, mlm=False, return_tensors="pt"
    )
    return df, data_collator
def queryGPT(question):
    """Answer ``question`` using the module-level GPT-2 model and tokenizer.

    Thin adapter wired up as the Gradio button handler: it forwards the raw
    textbox string to ``generate_response`` with its default ``max_length``
    and returns the decoded text.
    """
    answer = generate_response(model=__model, tokenizer=__tokenizer, prompt=question)
    return answer
def generate_response(model, tokenizer, prompt, max_length=200):
    """Generate one continuation of ``prompt`` and return it as text.

    Args:
        model: causal LM exposing ``generate`` (here a ``GPT2LMHeadModel``).
        tokenizer: tokenizer exposing ``encode``, ``decode`` and
            ``eos_token_id`` (here a ``GPT2Tokenizer``).
        prompt: the question text to continue.
        max_length: passed to ``generate`` as ``max_length``. This caps the
            total sequence length, i.e. prompt tokens plus generated tokens.

    Returns:
        str: the first generated sequence, decoded with special tokens
        removed. For decoder-only models such as GPT-2, ``generate`` returns
        the prompt ids followed by the new ones, so the text begins with the
        prompt. An empty prompt yields ``""``.
    """
    # An empty prompt encodes to an empty tensor that generate() cannot
    # start from, so short-circuit instead of raising inside the model.
    if not prompt:
        return ""
    input_ids = tokenizer.encode(prompt, return_tensors="pt")  # PyTorch tensor
    # A single unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(input_ids)
    output = model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        # GPT-2's tokenizer defines no pad token by default; reuse EOS.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Minimal Gradio front end: a question box, an answer box and a submit button.
with gr.Blocks() as demo:
    txt_input = gr.Textbox(lines=2, label="Input Question")
    txt_output = gr.Textbox(label="Answer", value="")
    btn = gr.Button("Submit")
    # On click, pass the question text to queryGPT and show its return value
    # in the answer box.
    btn.click(fn=queryGPT, inputs=[txt_input], outputs=[txt_output])

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()