# GPT-2 question-answering demo (Gradio app for a Hugging Face Space).
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, TextDataset, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

# Base checkpoint used for both the tokenizer and the language model.
__checkpoint = "gpt2"
__tokenizer = GPT2Tokenizer.from_pretrained(__checkpoint)
__model = GPT2LMHeadModel.from_pretrained(__checkpoint)

# Data collator for causal LM (mlm=False => no masked-language-modeling).
# BUG FIX: the original passed `tokenizer=tokenizer`, but no name `tokenizer`
# exists at module level (only `__tokenizer`), raising NameError at import.
data_collator = DataCollatorForLanguageModeling(tokenizer=__tokenizer, mlm=False, return_tensors="pt")
def queryGPT(question):
    """Gradio callback: return the GPT-2 completion for *question*.

    Delegates to ``generate_response`` with the module-level model and
    tokenizer loaded at import time.
    """
    return generate_response(__model, __tokenizer, question)
def generate_response(__model, __tokenizer, prompt, max_length=200):
    """Generate a text continuation of *prompt* with the given model.

    Args:
        __model: causal LM exposing ``.generate(...)`` (e.g. GPT2LMHeadModel).
        __tokenizer: tokenizer exposing ``.encode``/``.decode`` and
            ``eos_token_id``.
        prompt: input text to continue.
        max_length: maximum total token length (prompt + generation).

    Returns:
        The decoded sequence (prompt included), special tokens stripped.
    """
    # BUG FIX: the original called the undefined global `tokenizer` here,
    # raising NameError on every call; use the `__tokenizer` parameter.
    input_ids = __tokenizer.encode(prompt, return_tensors="pt")  # 'pt' -> PyTorch tensor
    # Single unpadded prompt: every position is real input, so mask is all ones.
    attention_mask = torch.ones_like(input_ids)
    # GPT-2 defines no pad token; reuse EOS to avoid the generate() warning.
    pad_token_id = __tokenizer.eos_token_id
    output = __model.generate(
        input_ids,
        max_length=max_length,
        num_return_sequences=1,
        attention_mask=attention_mask,
        pad_token_id=pad_token_id,
    )
    return __tokenizer.decode(output[0], skip_special_tokens=True)
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks() as demo:
    txt_input = gr.Textbox(label="Input Question", lines=2)
    txt_output = gr.Textbox(value="", label="Answer")
    btn = gr.Button(value="Submit")
    # Wire the submit button to the GPT-2 callback.
    btn.click(queryGPT, inputs=[txt_input], outputs=[txt_output])

if __name__ == "__main__":
    demo.launch()