Doron Adler committed
Commit b18bc66
1 Parent(s): 0a3b879

Maximum number of new tokens slider

Files changed (1): app.py +6 -4
app.py CHANGED
@@ -13,12 +13,12 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 n_gpu = 0 if torch.cuda.is_available()==False else torch.cuda.device_count()
 model.to(device)
 
-def generate(text = ""):
+def generate(text = "", max_new_tokens = 128):
     streamer = TextIteratorStreamer(tok, timeout=10.)
     if len(text) == 0:
         text = " "
     inputs = tok([text], return_tensors="pt").to(device)
-    generation_kwargs = dict(inputs, streamer=streamer, repetition_penalty=2.0, do_sample=True, top_k=40, top_p=0.97, max_new_tokens=128, pad_token_id=model.config.eos_token_id, early_stopping=True, no_repeat_ngram_size=4)
+    generation_kwargs = dict(inputs, streamer=streamer, repetition_penalty=2.0, do_sample=True, top_k=40, top_p=0.97, max_new_tokens=max_new_tokens, pad_token_id=model.config.eos_token_id, early_stopping=True, no_repeat_ngram_size=4)
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
     generated_text = ""
@@ -35,8 +35,10 @@ def generate(text = ""):
 demo = gr.Interface(
     title="TextIteratorStreamer + Gradio demo",
     fn=generate,
-    inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
-    outputs=gr.outputs.Textbox(label="Generated Text"),
+    inputs=[gr.inputs.Textbox(lines=5, label="Input Text"),
+            gr.inputs.Slider(default=128, minimum=5, maximum=256, step=1, label="Maximum number of new tokens")],
+    outputs=gr.outputs.Textbox(label="Generated Text"),
+    allow_flagging="never"
 )
 
 demo.queue()
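
Note: the hunks above elide the body of the streaming loop (app.py lines 25-34). For context, below is a minimal sketch of the pattern this commit builds on: a Gradio generator function that runs model.generate on a worker thread and drains a TextIteratorStreamer on the main thread. The "gpt2" checkpoint and the exact loop body are illustrative assumptions, not part of this commit.

from threading import Thread

import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder checkpoint; the Space's actual model/tokenizer are not shown in this diff.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def generate(text="", max_new_tokens=128):
    streamer = TextIteratorStreamer(tok, timeout=10.)
    if len(text) == 0:
        text = " "
    inputs = tok([text], return_tensors="pt").to(device)
    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
    # model.generate blocks until decoding finishes, so it runs on a worker
    # thread while the streamer is consumed here and yielded to Gradio.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = ""
    for new_text in streamer:        # yields decoded chunks as they arrive
        generated_text += new_text
        yield generated_text         # Gradio refreshes the output per yield
    thread.join()

Because generate is a generator, the demo.queue() call at the end of app.py is what lets Gradio stream each yielded partial string to the output Textbox as it arrives.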