Kumarkishalaya committed on
Commit
f72c8c5
1 Parent(s): 6e58fbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -10,15 +10,15 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
10
  trained_model.to(device)
11
  untrained_model.to(device)
12
 
13
def generate(commentary_text):
    """Continue a commentary prompt with both models.

    Runs the same prompt through the trained and the untrained model and
    returns the two decoded continuations as a (trained, untrained) pair.
    """
    def _complete(model, tokenizer):
        # Deterministic beam-search continuation of the prompt,
        # tokenized onto the same device the models live on.
        prompt_ids = tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
        generated = model.generate(prompt_ids, max_length=60, num_beams=5, do_sample=False)
        return tokenizer.decode(generated[0], skip_special_tokens=True)

    trained_text = _complete(trained_model, trained_tokenizer)
    untrained_text = _complete(untrained_model, untrained_tokenizer)
    return trained_text, untrained_text
@@ -26,10 +26,17 @@ def generate(commentary_text):
26
# Create Gradio interface: one text input, two text outputs (trained vs untrained).
_interface_config = {
    "fn": generate,
    "inputs": "text",
    "outputs": ["text", "text"],
    "title": "GPT-2 Text Generation",
    "description": "start writing a cricket commentary and GPT-2 will continue it using both a trained and untrained model.",
}
iface = gr.Interface(**_interface_config)
34
 
35
  # Launch the app
 
10
  trained_model.to(device)
11
  untrained_model.to(device)
12
 
13
def generate(commentary_text, max_length, temperature):
    """Generate continuations of a commentary prompt with both models.

    Args:
        commentary_text: Prompt string from the UI textbox.
        max_length: Total output length in tokens. Gradio sliders deliver
            floats, so this is cast to int before reaching ``model.generate``
            (which rejects non-int ``max_length``).
        temperature: Softmax temperature for sampling. With
            ``do_sample=False`` transformers ignores ``temperature`` (and
            warns), which would make the UI slider a no-op — so sampling is
            enabled here to honor it.

    Returns:
        Tuple of (finetuned_text, base_text) decoded strings.
    """
    max_length = int(max_length)      # slider values arrive as floats
    temperature = float(temperature)

    def _complete(model, tokenizer):
        # Tokenize onto the same device the models were moved to.
        prompt_ids = tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
        output = model.generate(
            prompt_ids,
            max_length=max_length,
            num_beams=5,
            do_sample=True,  # required for temperature to take effect
            temperature=temperature,
        )
        return tokenizer.decode(output[0], skip_special_tokens=True)

    # Finetuned model first, then the untouched base model.
    trained_text = _complete(trained_model, trained_tokenizer)
    untrained_text = _complete(untrained_model, untrained_tokenizer)
    return trained_text, untrained_text
 
26
# Create Gradio interface.
# NOTE: the gr.inputs / gr.outputs namespaces and the Slider `default=` kwarg
# were removed in Gradio 3.x/4.x; current Gradio requires the top-level
# components with `value=` for the initial slider position.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
        # step=1 keeps token counts integral in the UI; generate() still casts defensively.
        gr.Slider(minimum=10, maximum=100, value=50, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature"),
    ],
    outputs=[
        gr.Textbox(label="commentary generation from finetuned GPT2 Model"),
        gr.Textbox(label="commentary generation from base GPT2 Model"),
    ],
    title="GPT-2 Text Generation",
    description="start writing a cricket commentary and GPT-2 will continue it using both a finetuned and base model.",
)
41
 
42
  # Launch the app