File size: 2,528 Bytes
c70da6e
e648a11
ed738dc
c70da6e
 
 
6e58fbc
 
ed738dc
4619405
 
c70da6e
f72c8c5
0a8b7e2
 
 
 
 
 
 
 
 
 
 
 
 
4619405
 
f72c8c5
0a8b7e2
 
 
 
 
 
 
 
 
 
 
 
4619405
 
 
c70da6e
8cc2393
4619405
 
f72c8c5
75a9553
f9da95f
1062c1a
f72c8c5
 
75a9553
 
f72c8c5
4619405
f72c8c5
4619405
8cc2393
 
 
fdaa910
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import gradio as gr
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


def _load_gpt2(repo_id):
    """Load the GPT-2 tokenizer and LM-head model published under *repo_id*.

    The model is moved onto the module-level ``device`` before being returned,
    so callers receive a ready-to-use ``(tokenizer, model)`` pair.
    """
    tokenizer = GPT2Tokenizer.from_pretrained(repo_id)
    # nn.Module.to() returns the module itself, so chaining keeps the same object.
    model = GPT2LMHeadModel.from_pretrained(repo_id).to(device)
    return tokenizer, model


# Fine-tuned Hub checkpoint (presumably trained on cricket commentary, per the
# UI description) and the stock GPT-2 baseline it is compared against.
trained_tokenizer, trained_model = _load_gpt2("Kumarkishalaya/GPT-2-next-word-prediction")
untrained_tokenizer, untrained_model = _load_gpt2("gpt2")

def _continue_prompt(model, tokenizer, prompt, max_length, temperature):
    """Return *model*'s decoded continuation of *prompt*.

    Args:
        model: GPT-2 LM-head model; expected to already live on ``device``.
        tokenizer: The tokenizer that matches *model*.
        prompt: Non-empty text to continue.
        max_length: Total token budget of the output sequence, prompt tokens
            included (``transformers`` ``generate`` semantics).
        temperature: Sampling temperature (must be > 0).

    Returns:
        The first generated sequence decoded to text, prompt included, with
        special tokens stripped.
    """
    # A single prompt never needs padding. GPT-2 tokenizers ship without a pad
    # token, so passing padding=True makes the stock "gpt2" tokenizer raise
    # ValueError. The attention mask is returned either way.
    encoded = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(
        encoded.input_ids.to(device),
        attention_mask=encoded.attention_mask.to(device),
        max_length=max_length,
        # num_beams > 1 together with do_sample=True selects beam-sample decoding.
        num_beams=5,
        do_sample=True,
        temperature=temperature,
        # GPT-2 defines no pad token; reuse EOS so open-ended generation has one.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def generate(commentary_text, max_length, temperature):
    """Continue *commentary_text* with both the fine-tuned and the base GPT-2 model.

    Args:
        commentary_text: Prompt typed into the UI textbox.
        max_length: Total token budget per output (prompt included). It may
            arrive from the Gradio slider as a float, so it is coerced to int.
        temperature: Sampling temperature from the UI slider.

    Returns:
        ``(finetuned_text, base_text)``. Both are empty strings when the
        prompt is empty, because generation needs at least one input token.
    """
    if not commentary_text:
        return "", ""
    max_length = int(max_length)

    trained_text = _continue_prompt(
        trained_model, trained_tokenizer, commentary_text, max_length, temperature
    )
    untrained_text = _continue_prompt(
        untrained_model, untrained_tokenizer, commentary_text, max_length, temperature
    )
    return trained_text, untrained_text

# Create Gradio interface
def _build_interface():
    """Assemble the side-by-side fine-tuned vs. base GPT-2 demo around ``generate``.

    The inputs map positionally onto ``generate(commentary_text, max_length,
    temperature)``. The two output boxes receive its returned pair in order.
    """
    prompt_box = gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt")
    length_slider = gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Max Length")
    temperature_slider = gr.Slider(minimum=0.01, maximum=1.99, value=0.7, label="Temperature")

    finetuned_box = gr.Textbox(label="commentary generation from finetuned GPT2 Model")
    base_box = gr.Textbox(label="commentary generation from base GPT2 Model")

    return gr.Interface(
        fn=generate,
        inputs=[prompt_box, length_slider, temperature_slider],
        outputs=[finetuned_box, base_box],
        title="GPT-2 Text Generation",
        description="start writing a cricket commentary and GPT-2 will continue it using both a finetuned and base model.",
    )


iface = _build_interface()

# Launch the app
if __name__ == "__main__":
    iface.launch()