File size: 2,066 Bytes
c70da6e e648a11 ed738dc c70da6e 6e58fbc ed738dc 4619405 c70da6e f72c8c5 4619405 75a9553 4619405 f72c8c5 4619405 75a9553 4619405 c70da6e 8cc2393 4619405 f72c8c5 75a9553 f9da95f fdaa910 f72c8c5 75a9553 f72c8c5 4619405 f72c8c5 4619405 8cc2393 fdaa910 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 |
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import torch
# Run on the GPU when torch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Finetuned checkpoint and its tokenizer, loaded by repo id via from_pretrained.
trained_tokenizer = GPT2Tokenizer.from_pretrained("Kumarkishalaya/GPT-2-next-word-prediction")
trained_model = GPT2LMHeadModel.from_pretrained("Kumarkishalaya/GPT-2-next-word-prediction").to(device)

# Stock "gpt2" checkpoint, used as the baseline the finetuned model is compared against.
untrained_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
untrained_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
def _continue_text(model, tokenizer, prompt, max_length, temperature):
    """Continue ``prompt`` with ``model`` and return the decoded sequence.

    Args:
        model: Causal LM exposing ``generate`` (a GPT2LMHeadModel in this file).
        tokenizer: Tokenizer paired with ``model``.
        prompt: Non-empty text to continue.
        max_length: Cap on the total token count (prompt + continuation),
            forwarded as ``generate(max_length=...)``.
        temperature: Sampling temperature.

    Returns:
        The first generated sequence decoded to text (prompt included), with
        special tokens stripped.
    """
    # Forward the whole encoding so ``attention_mask`` reaches ``generate``
    # alongside ``input_ids`` instead of being silently dropped.
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    # GPT-2 tokenizers ship without a pad token by default; fall back to EOS
    # explicitly rather than letting ``generate`` warn and do it implicitly.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id
    output = model.generate(
        **encoded,
        max_length=max_length,
        num_beams=5,
        do_sample=True,
        temperature=temperature,
        pad_token_id=pad_token_id,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)


def generate(commentary_text, max_length, temperature):
    """Continue a prompt with both the finetuned and the base GPT-2 model.

    Args:
        commentary_text: Prompt text entered in the UI.
        max_length: Total-length cap (prompt + generated tokens). Coerced to
            ``int`` because a UI slider may deliver it as a float.
        temperature: Sampling temperature shared by both models.

    Returns:
        ``(trained_text, untrained_text)``: the finetuned model's output and
        the base model's output. Both are ``""`` for an empty prompt, which
        would otherwise hand ``generate`` a zero-length input.
    """
    if not commentary_text:
        return "", ""
    max_length = int(max_length)
    # Generate text using the finetuned model
    trained_text = _continue_text(
        trained_model, trained_tokenizer, commentary_text, max_length, temperature
    )
    # Generate text using the base model
    untrained_text = _continue_text(
        untrained_model, untrained_tokenizer, commentary_text, max_length, temperature
    )
    return trained_text, untrained_text
# Web UI: a prompt box plus two sliders feed ``generate`` positionally in the
# listed order; its two return values fill the two output boxes in order.
iface = gr.Interface(
    title="GPT-2 Text Generation",
    description="start writing a cricket commentary and GPT-2 will continue it using both a finetuned and base model.",
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Prompt",
            lines=2,
            placeholder="Enter your prompt here...",
        ),
        gr.Slider(
            label="Max Length",
            minimum=10,
            maximum=100,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=2.0,
            value=0.7,
        ),
    ],
    outputs=[
        gr.Textbox(label="commentary generation from finetuned GPT2 Model"),
        gr.Textbox(label="commentary generation from base GPT2 Model"),
    ],
)
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    iface.launch()