Commit
•
1062c1a
1
Parent(s):
f9da95f
Update app.py
Browse files
app.py
CHANGED
@@ -11,6 +11,9 @@ trained_model.to(device)
|
|
11 |
untrained_model.to(device)
|
12 |
|
13 |
def generate(commentary_text, max_length, temperature):
|
|
|
|
|
|
|
14 |
# Generate text using the finetuned model
|
15 |
input_ids = trained_tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
|
16 |
trained_output = trained_model.generate(input_ids, max_length=max_length, num_beams=5, do_sample=True, temperature=temperature)
|
@@ -29,7 +32,7 @@ iface = gr.Interface(
|
|
29 |
inputs=[
|
30 |
gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
|
31 |
gr.Slider(minimum=10, maximum=100, value=50, step=1,label="Max Length"),
|
32 |
-
gr.Slider(minimum=0.01, maximum=
|
33 |
],
|
34 |
outputs=[
|
35 |
gr.Textbox(label="commentary generation from finetuned GPT2 Model"),
|
|
|
11 |
untrained_model.to(device)
|
12 |
|
13 |
def generate(commentary_text, max_length, temperature):
|
14 |
+
if temperature <= 0:
|
15 |
+
return "Error: Temperature must be a strictly positive float.", "Error: Temperature must be a strictly positive float."
|
16 |
+
|
17 |
# Generate text using the finetuned model
|
18 |
input_ids = trained_tokenizer(commentary_text, return_tensors="pt").input_ids.to(device)
|
19 |
trained_output = trained_model.generate(input_ids, max_length=max_length, num_beams=5, do_sample=True, temperature=temperature)
|
|
|
32 |
inputs=[
|
33 |
gr.Textbox(lines=2, placeholder="Enter your prompt here...", label="Prompt"),
|
34 |
gr.Slider(minimum=10, maximum=100, value=50, step=1,label="Max Length"),
|
35 |
+
gr.Slider(minimum=0.01, maximum=1.99, value=0.7, label="Temperature")
|
36 |
],
|
37 |
outputs=[
|
38 |
gr.Textbox(label="commentary generation from finetuned GPT2 Model"),
|