egosumkira commited on
Commit
265eae6
1 Parent(s): ea58b06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -13,10 +13,10 @@ story = pipeline(
13
  )
14
 
15
 
16
- def generate(tags_text, temp=1.0, n_beams=3):
17
  tags = tags_text.split(", ")
18
  prefix = f"~^{'^'.join(tags)}~@"
19
- g_text = story(prefix, temperature=1.0, repetition_penalty=7.0, num_beams=3)[0]['generated_text']
20
  return g_text[g_text.find("@") + 1:]
21
 
22
 
@@ -30,7 +30,8 @@ iface = gr.Interface(generate,
30
  inputs = [
31
  gr.Textbox(label="Keywords (comma separated)"),
32
  gr.inputs.Slider(0, 2, default=1.0, step=0.05, label="Temperature"),
33
- gr.inputs.Slider(1, 10, default=3, label="Number of beams", step=1)
 
34
  ],
35
  outputs = gr.Textbox(label="Output"),
36
  title=title,
 
13
  )
14
 
15
 
16
+ def generate(tags_text, temp, n_beams, max_len):
17
  tags = tags_text.split(", ")
18
  prefix = f"~^{'^'.join(tags)}~@"
19
+ g_text = story(prefix, temperature=temp, repetition_penalty=7.0, num_beams=n_beams, max_length=max_len)[0]['generated_text']
20
  return g_text[g_text.find("@") + 1:]
21
 
22
 
 
30
  inputs = [
31
  gr.Textbox(label="Keywords (comma separated)"),
32
  gr.inputs.Slider(0, 2, default=1.0, step=0.05, label="Temperature"),
33
+ gr.inputs.Slider(1, 10, default=3, label="Number of beams", step=1),
34
+ gr.Number(label="Max lenght", value=128)
35
  ],
36
  outputs = gr.Textbox(label="Output"),
37
  title=title,