jwkirchenbauer commited on
Commit
a134a9d
1 Parent(s): b5b3015

fixed args

Browse files
Files changed (1) hide show
  1. demo_watermark.py +16 -6
demo_watermark.py CHANGED
@@ -157,7 +157,7 @@ def parse_args():
157
  args = parser.parse_args()
158
  return args
159
 
160
- def load_model():
161
  args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]])
162
  args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]])
163
  if args.is_seq2seq_model:
@@ -178,7 +178,7 @@ def load_model():
178
 
179
  return model, tokenizer, device
180
 
181
- def generate(prompt, args, model=None, tokenizer=None):
182
 
183
  print(f"Generating with {args}")
184
 
@@ -261,7 +261,7 @@ def detect(input_text, args, device=None, tokenizer=None):
261
 
262
  def run_gradio(args, model=None, device=None, tokenizer=None):
263
 
264
- generate_partial = partial(generate, model=model, tokenizer=tokenizer)
265
  detect_partial = partial(detect, device=device, tokenizer=tokenizer)
266
 
267
  with gr.Blocks() as demo:
@@ -447,9 +447,19 @@ def main(args):
447
  print("Prompt:")
448
  print(input_text)
449
 
450
- _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text, args)
451
- without_watermark_detection_result = detect(decoded_output_without_watermark, args)
452
- with_watermark_detection_result = detect(decoded_output_with_watermark, args)
 
 
 
 
 
 
 
 
 
 
453
 
454
  print("#"*term_width)
455
  print("Output without watermark:")
 
157
  args = parser.parse_args()
158
  return args
159
 
160
+ def load_model(args):
161
  args.is_seq2seq_model = any([(model_type in args.model_name_or_path) for model_type in ["t5","T0"]])
162
  args.is_decoder_only_model = any([(model_type in args.model_name_or_path) for model_type in ["gpt","opt","bloom"]])
163
  if args.is_seq2seq_model:
 
178
 
179
  return model, tokenizer, device
180
 
181
+ def generate(prompt, args, model=None, device=None, tokenizer=None):
182
 
183
  print(f"Generating with {args}")
184
 
 
261
 
262
  def run_gradio(args, model=None, device=None, tokenizer=None):
263
 
264
+ generate_partial = partial(generate, model=model, device=None, tokenizer=tokenizer)
265
  detect_partial = partial(detect, device=device, tokenizer=tokenizer)
266
 
267
  with gr.Blocks() as demo:
 
447
  print("Prompt:")
448
  print(input_text)
449
 
450
+ _, _, decoded_output_without_watermark, decoded_output_with_watermark, _ = generate(input_text,
451
+ args,
452
+ model=model,
453
+ device=device,
454
+ tokenizer=tokenizer)
455
+ without_watermark_detection_result = detect(decoded_output_without_watermark,
456
+ args,
457
+ device=device,
458
+ tokenizer=tokenizer)
459
+ with_watermark_detection_result = detect(decoded_output_with_watermark,
460
+ args,
461
+ device=device,
462
+ tokenizer=tokenizer)
463
 
464
  print("#"*term_width)
465
  print("Output without watermark:")