jwkirchenbauer commited on
Commit
dafc0b4
1 Parent(s): 6b83b21

beam field fix

Browse files
Files changed (1) hide show
  1. demo_watermark.py +2 -2
demo_watermark.py CHANGED
@@ -97,7 +97,7 @@ def parse_args():
97
  parser.add_argument(
98
  "--n_beams",
99
  type=int,
100
- choices=[1,4,8],
101
  help="Number of beams to use for beam search. 1 is normal greedy decoding",
102
  )
103
  parser.add_argument(
@@ -399,7 +399,7 @@ def run_gradio(args, model=None, device=None, tokenizer=None):
399
  return gr.update(visible=False)
400
  elif value == "greedy":
401
  return gr.update(visible=True)
402
- def update_n_beams(session_state, value): session_state.n_beams = int(value); return session_state
403
  def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
404
  def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
405
  def update_normalizers(session_state, value): session_state.normalizers = value; return session_state
 
97
  parser.add_argument(
98
  "--n_beams",
99
  type=int,
100
+ default=1,
101
  help="Number of beams to use for beam search. 1 is normal greedy decoding",
102
  )
103
  parser.add_argument(
 
399
  return gr.update(visible=False)
400
  elif value == "greedy":
401
  return gr.update(visible=True)
402
+ def update_n_beams(session_state, value): session_state.n_beams = value; return session_state
403
  def update_max_new_tokens(session_state, value): session_state.max_new_tokens = int(value); return session_state
404
  def update_ignore_repeated_bigrams(session_state, value): session_state.ignore_repeated_bigrams = value; return session_state
405
  def update_normalizers(session_state, value): session_state.normalizers = value; return session_state