# app.py — GPT-2 question-answering demo (Hugging Face Space by paavansundar)
# Note: the lines below were viewer-page residue captured during download
# ("Update app.py", commit 01e535b, "raw / history / blame", "1.41 kB");
# they are preserved here as comments so the file parses as Python.
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments
# Model/tokenizer setup: load the pretrained GPT-2 checkpoint once at import time.
__checkpoint = "gpt2"
__tokenizer = GPT2Tokenizer.from_pretrained(__checkpoint)
__model = GPT2LMHeadModel.from_pretrained(__checkpoint)
# Data collator for causal language modeling (mlm=False = plain next-token LM).
# NOTE(review): this collator is never used below — presumably left over from a
# fine-tuning experiment.
# BUG FIX: the original passed the undefined name `tokenizer`; the module-level
# variable is `__tokenizer`, so importing this file raised NameError.
data_collator = DataCollatorForLanguageModeling(tokenizer=__tokenizer, mlm=False, return_tensors="pt")
def queryGPT(question):
    """Gradio callback: answer *question* using the module-level GPT-2 model.

    Thin wrapper that forwards to :func:`generate_response` with the
    tokenizer/model pair loaded at import time.
    """
    answer = generate_response(__model, __tokenizer, question)
    return answer
def generate_response(__model, __tokenizer, prompt, max_length=200):
    """Generate a text completion for *prompt* with a causal language model.

    Args:
        __model: model exposing a HF-style ``generate(input_ids, ...)`` method.
        __tokenizer: tokenizer exposing ``encode``/``decode`` and ``eos_token_id``.
        prompt: input text to complete.
        max_length: maximum total length (in tokens) of the generated sequence.

    Returns:
        The decoded generated text with special tokens stripped.
    """
    # BUG FIX: the original called the undefined global `tokenizer`; the
    # parameter is named `__tokenizer`, so every call raised NameError.
    input_ids = __tokenizer.encode(prompt, return_tensors="pt")  # 'pt' -> PyTorch tensor
    # No padding is applied to a single prompt, so attend to every position.
    attention_mask = torch.ones_like(input_ids)
    # GPT-2 defines no pad token; reusing EOS is the conventional workaround.
    pad_token_id = __tokenizer.eos_token_id
    output = __model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
    )
    # `output` is a batch of token-id sequences; decode the single sequence.
    return __tokenizer.decode(output[0], skip_special_tokens=True)
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    question_box = gr.Textbox(label="Input Question", lines=2)
    answer_box = gr.Textbox(value="", label="Answer")
    submit_btn = gr.Button(value="Submit")
    # Route the typed question through the GPT-2 helper and show the answer.
    submit_btn.click(queryGPT, inputs=[question_box], outputs=[answer_box])

if __name__ == "__main__":
    demo.launch()