File size: 2,066 Bytes
c70da6e
e648a11
ed738dc
c70da6e
 
 
6e58fbc
 
ed738dc
4619405
 
c70da6e
f72c8c5
 
4619405
75a9553
4619405
 
f72c8c5
4619405
75a9553
4619405
 
 
c70da6e
8cc2393
4619405
 
f72c8c5
75a9553
f9da95f
fdaa910
f72c8c5
 
75a9553
 
f72c8c5
4619405
f72c8c5
4619405
8cc2393
 
 
fdaa910
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import torch

# Load the finetuned next-word-prediction checkpoint and the stock GPT-2
# baseline so the app can show both generations side by side.
trained_tokenizer = GPT2Tokenizer.from_pretrained("Kumarkishalaya/GPT-2-next-word-prediction")
trained_model = GPT2LMHeadModel.from_pretrained("Kumarkishalaya/GPT-2-next-word-prediction")
untrained_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
untrained_model = GPT2LMHeadModel.from_pretrained("gpt2")
# Prefer GPU when available; both models are moved to the same device that
# generate() later sends its input_ids to.
device = "cuda" if torch.cuda.is_available() else "cpu"
trained_model.to(device)
untrained_model.to(device)

def generate(commentary_text, max_length, temperature):
    """Continue *commentary_text* with both the finetuned and the base GPT-2.

    Parameters
    ----------
    commentary_text : str
        Prompt to continue.
    max_length : int
        Maximum total length (prompt + continuation) in tokens.
    temperature : float
        Sampling temperature passed to ``model.generate``.

    Returns
    -------
    tuple[str, str]
        ``(finetuned_text, base_text)`` — decoded continuations, special
        tokens stripped.
    """

    def _run(tokenizer, model):
        # One tokenize -> generate -> decode pass; shared by both models.
        encoded = tokenizer(commentary_text, return_tensors="pt").to(device)
        # Inference only: skip autograd bookkeeping to save memory/time.
        with torch.no_grad():
            output = model.generate(
                encoded.input_ids,
                # Pass the mask explicitly so generation does not have to
                # guess padding (and HF stops warning about it).
                attention_mask=encoded.attention_mask,
                max_length=max_length,
                num_beams=5,
                do_sample=True,
                temperature=temperature,
                # GPT-2 has no pad token; reuse EOS as is conventional.
                pad_token_id=tokenizer.eos_token_id,
            )
        return tokenizer.decode(output[0], skip_special_tokens=True)

    trained_text = _run(trained_tokenizer, trained_model)
    untrained_text = _run(untrained_tokenizer, untrained_model)
    return trained_text, untrained_text

# Assemble the Gradio UI: one prompt box plus two sliders in,
# two text boxes (finetuned vs base generation) out.
_input_widgets = [
    gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
    gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Max Length"),
    gr.Slider(minimum=0.01, maximum=2.0, value=0.7, label="Temperature"),
]
_output_widgets = [
    gr.Textbox(label="commentary generation from finetuned GPT2 Model"),
    gr.Textbox(label="commentary generation from base GPT2 Model"),
]
iface = gr.Interface(
    fn=generate,
    inputs=_input_widgets,
    outputs=_output_widgets,
    title="GPT-2 Text Generation",
    description="start writing a cricket commentary and GPT-2 will continue it using both a finetuned and base model.",
)

# Launch the app
if __name__ == "__main__":
    # Blocks here, serving the interface on Gradio's default local port.
    iface.launch()