File size: 2,680 Bytes
c70da6e e648a11 ed738dc c70da6e 6e58fbc 24dbe3e ed738dc 4619405 c70da6e 24dbe3e f72c8c5 0a8b7e2 4619405 f72c8c5 0a8b7e2 4619405 c70da6e 8cc2393 4619405 f72c8c5 75a9553 f9da95f 1062c1a f72c8c5 75a9553 f72c8c5 4619405 f72c8c5 4619405 8cc2393 fdaa910 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import torch
# Hugging Face Hub checkpoints: the commentary finetune and the stock GPT-2 baseline.
_FINETUNED_CHECKPOINT = "Kumarkishalaya/GPT-2-next-word-prediction"
_BASE_CHECKPOINT = "gpt2"

# Finetuned model and its tokenizer.
trained_tokenizer = GPT2Tokenizer.from_pretrained(_FINETUNED_CHECKPOINT)
trained_model = GPT2LMHeadModel.from_pretrained(_FINETUNED_CHECKPOINT)

# Base GPT-2 model and tokenizer, used as the comparison baseline.
untrained_tokenizer = GPT2Tokenizer.from_pretrained(_BASE_CHECKPOINT)
untrained_model = GPT2LMHeadModel.from_pretrained(_BASE_CHECKPOINT)

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
for _model in (trained_model, untrained_model):
    _model.to(device)

# Use the end-of-sequence token as the padding token so that tokenizing
# with padding=True has a pad token to work with.
for _tokenizer in (trained_tokenizer, untrained_tokenizer):
    _tokenizer.pad_token = _tokenizer.eos_token
def _continue_text(model, tokenizer, prompt, max_length, temperature):
    """Continue ``prompt`` with ``model`` and return the decoded sequence.

    Args:
        model: A causal language model exposing ``generate``.
        tokenizer: The tokenizer paired with ``model``.
        prompt: Non-empty text to continue.
        max_length: Total token budget (prompt + continuation); may be a
            float coming from a UI slider.
        temperature: Sampling temperature passed through to ``generate``.

    Returns:
        The decoded first output sequence with special tokens removed. For a
        decoder-only model the output starts with the prompt tokens, so the
        returned text includes the prompt.
    """
    encoded = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device)

    # ``max_length`` counts the prompt too. Cast to int (slider values may be
    # floats) and leave room for at least one new token, so that prompts
    # longer than the slider value do not trip generate()'s length check.
    prompt_len = input_ids.shape[-1]
    token_budget = max(int(max_length), prompt_len + 1)

    output_ids = model.generate(
        input_ids,
        max_length=token_budget,
        num_beams=5,
        do_sample=True,
        temperature=temperature,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def generate(commentary_text, max_length, temperature):
    """Continue a prompt with both the finetuned and the base GPT-2 model.

    Args:
        commentary_text: The user's prompt.
        max_length: Total token budget per output (prompt + continuation).
        temperature: Sampling temperature for beam-sampling generation.

    Returns:
        A ``(finetuned_text, base_text)`` tuple. An empty or ``None`` prompt
        returns ``("", "")``, because an empty input cannot be passed to
        ``generate()``.
    """
    if not commentary_text:
        return "", ""

    trained_text = _continue_text(
        trained_model, trained_tokenizer, commentary_text, max_length, temperature
    )
    untrained_text = _continue_text(
        untrained_model, untrained_tokenizer, commentary_text, max_length, temperature
    )
    return trained_text, untrained_text
def _build_interface():
    """Build the Gradio UI that feeds a prompt and sampling settings to generate().

    Inputs are the prompt, the total length budget and the temperature.
    Outputs are the finetuned and base model continuations, in that order.
    """
    prompt_box = gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt")
    length_slider = gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Max Length")
    temperature_slider = gr.Slider(minimum=0.01, maximum=1.99, value=0.7, label="Temperature")
    finetuned_box = gr.Textbox(label="commentary generation from finetuned GPT2 Model")
    base_box = gr.Textbox(label="commentary generation from base GPT2 Model")

    return gr.Interface(
        fn=generate,
        inputs=[prompt_box, length_slider, temperature_slider],
        outputs=[finetuned_box, base_box],
        title="GPT-2 Text Generation",
        description=(
            "start writing a cricket commentary and GPT-2 will continue it "
            "using both a finetuned and base model."
        ),
    )


# Kept at module level so importers and the __main__ guard below can reach it.
iface = _build_interface()

if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    iface.launch()