joaogante HF staff committed on
Commit
fed0a26
·
1 Parent(s): 46f6023

committing broken state

Browse files
Files changed (1) hide show
  1. app.py +34 -24
app.py CHANGED
@@ -1,7 +1,7 @@
1
  from collections.abc import Sequence
2
  import json
3
  import random
4
- from typing import Optional
5
 
6
  import gradio as gr
7
  import spaces
@@ -12,7 +12,7 @@ import transformers
12
  # the nature of the task (e.g., factual responses are lower entropy) or it could
13
  # be another
14
 
15
- _MODEL_IDENTIFIER = 'google/gemma-2b'
16
  _DETECTOR_IDENTIFIER = 'gg-hf/detector_2b_1.0_demo'
17
 
18
  _PROMPTS: tuple[str] = (
@@ -21,11 +21,10 @@ _PROMPTS: tuple[str] = (
21
  'prompt 3',
22
  )
23
 
24
- _CORRECT_ANSWERS: dict[str, bool] = {}
25
-
26
  _TORCH_DEVICE = (
27
  torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
28
  )
 
29
 
30
  _WATERMARK_CONFIG_DICT = dict(
31
  ngram_len=5,
@@ -70,7 +69,7 @@ _WATERMARK_CONFIG = transformers.generation.SynthIDTextWatermarkingConfig(
70
  **_WATERMARK_CONFIG_DICT
71
  )
72
 
73
- tokenizer = transformers.AutoTokenizer.from_pretrained(_MODEL_IDENTIFIER)
74
  tokenizer.pad_token_id = tokenizer.eos_token_id
75
 
76
  model = transformers.AutoModelForCausalLM.from_pretrained(_MODEL_IDENTIFIER)
@@ -89,6 +88,7 @@ detector_module.to(_TORCH_DEVICE)
89
  detector = transformers.generation.watermarking.SynthIDTextWatermarkDetector(
90
  detector_module=detector_module,
91
  logits_processor=logits_processor,
 
92
  )
93
 
94
 
@@ -98,8 +98,9 @@ def generate_outputs(
98
  watermarking_config: Optional[
99
  transformers.generation.SynthIDTextWatermarkingConfig
100
  ] = None,
101
- ) -> Sequence[str]:
102
- tokenized_prompts = tokenizer(prompts, return_tensors='pt').to(_TORCH_DEVICE)
 
103
  output_sequences = model.generate(
104
  **tokenized_prompts,
105
  watermarking_config=watermarking_config,
@@ -107,9 +108,10 @@ def generate_outputs(
107
  max_length=500,
108
  top_k=40,
109
  )
 
110
  detections = detector(output_sequences)
111
  print(detections)
112
- return tokenizer.batch_decode(output_sequences)
113
 
114
 
115
  with gr.Blocks() as demo:
@@ -236,25 +238,33 @@ with gr.Blocks() as demo:
236
  detect_btn = gr.Button('Detect', visible=False)
237
 
238
  def generate(*prompts):
239
- standard = generate_outputs(prompts=prompts)
240
- watermarked = generate_outputs(
241
  prompts=prompts,
242
  watermarking_config=_WATERMARK_CONFIG,
243
  )
244
- responses = standard + watermarked
245
- random.shuffle(responses)
 
 
 
 
 
 
 
 
246
 
247
- _CORRECT_ANSWERS.update({
248
- response: response in watermarked
249
- for response in responses
250
- })
251
 
252
  # Load model
253
  return {
254
  generate_btn: gr.Button(visible=False),
255
  generations_col: gr.Column(visible=True),
256
  generations_grp: gr.CheckboxGroup(
257
- responses,
258
  ),
259
  reveal_btn: gr.Button(visible=True),
260
  }
@@ -269,17 +279,17 @@ with gr.Blocks() as demo:
269
  choices: list[str] = []
270
  value: list[str] = []
271
 
272
- for response, is_watermarked in _CORRECT_ANSWERS.items():
273
- if is_watermarked and response in user_selections:
274
- choice = f'Correct! {response}'
275
- elif not is_watermarked and response not in user_selections:
276
- choice = f'Correct! {response}'
 
 
277
  else:
278
  choice = f'Incorrect. {response}'
279
 
280
  choices.append(choice)
281
- if is_watermarked:
282
- value.append(choice)
283
 
284
  return {
285
  reveal_btn: gr.Button(visible=False),
 
1
  from collections.abc import Sequence
2
  import json
3
  import random
4
+ from typing import Optional, Tuple
5
 
6
  import gradio as gr
7
  import spaces
 
12
  # the nature of the task (e.g., factual responses are lower entropy) or it could
13
  # be another
14
 
15
+ _MODEL_IDENTIFIER = 'google/gemma-2b-it'
16
  _DETECTOR_IDENTIFIER = 'gg-hf/detector_2b_1.0_demo'
17
 
18
  _PROMPTS: tuple[str] = (
 
21
  'prompt 3',
22
  )
23
 
 
 
24
  _TORCH_DEVICE = (
25
  torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
26
  )
27
+ _ANSWERS = []
28
 
29
  _WATERMARK_CONFIG_DICT = dict(
30
  ngram_len=5,
 
69
  **_WATERMARK_CONFIG_DICT
70
  )
71
 
72
+ tokenizer = transformers.AutoTokenizer.from_pretrained(_MODEL_IDENTIFIER, padding_side="left")
73
  tokenizer.pad_token_id = tokenizer.eos_token_id
74
 
75
  model = transformers.AutoModelForCausalLM.from_pretrained(_MODEL_IDENTIFIER)
 
88
  detector = transformers.generation.watermarking.SynthIDTextWatermarkDetector(
89
  detector_module=detector_module,
90
  logits_processor=logits_processor,
91
+ tokenizer=tokenizer,
92
  )
93
 
94
 
 
98
  watermarking_config: Optional[
99
  transformers.generation.SynthIDTextWatermarkingConfig
100
  ] = None,
101
+ ) -> Tuple[Sequence[str], torch.Tensor]:
102
+ tokenized_prompts = tokenizer(prompts, return_tensors='pt', padding="longest").to(_TORCH_DEVICE)
103
+ input_length = tokenized_prompts.input_ids.shape[1]
104
  output_sequences = model.generate(
105
  **tokenized_prompts,
106
  watermarking_config=watermarking_config,
 
108
  max_length=500,
109
  top_k=40,
110
  )
111
+ output_sequences = output_sequences[:, input_length:]
112
  detections = detector(output_sequences)
113
  print(detections)
114
+ return (tokenizer.batch_decode(output_sequences, skip_special_tokens=True), detections)
115
 
116
 
117
  with gr.Blocks() as demo:
 
238
  detect_btn = gr.Button('Detect', visible=False)
239
 
240
  def generate(*prompts):
241
+ standard, standard_detector = generate_outputs(prompts=prompts)
242
+ watermarked, watermarked_detector = generate_outputs(
243
  prompts=prompts,
244
  watermarking_config=_WATERMARK_CONFIG,
245
  )
246
+ upper_threshold = 0.9501
247
+ lower_threshold = 0.1209
248
+
249
+ def decision(score: float) -> str:
250
+ if score > upper_threshold:
251
+ return 'Watermarked'
252
+ elif lower_threshold < score < upper_threshold:
253
+ return 'Indeterminate'
254
+ else:
255
+ return 'Not watermarked'
256
 
257
+ responses = [(text, decision(score)) for text, score in zip(standard, standard_detector[0])]
258
+ responses += [(text, decision(score)) for text, score in zip(watermarked, watermarked_detector[0])]
259
+ random.shuffle(responses)
260
+ _ANSWERS = responses
261
 
262
  # Load model
263
  return {
264
  generate_btn: gr.Button(visible=False),
265
  generations_col: gr.Column(visible=True),
266
  generations_grp: gr.CheckboxGroup(
267
+ [response[0] for response in responses],
268
  ),
269
  reveal_btn: gr.Button(visible=True),
270
  }
 
279
  choices: list[str] = []
280
  value: list[str] = []
281
 
282
+ for (response, decision) in _ANSWERS:
283
+ if decision == "Watermarked":
284
+ value.append(choice)
285
+ if response in user_selections:
286
+ choice = f'Correct! {response}
287
+ elif decision == 'Indeterminate':
288
+ choice = f'Uncertain! {response}'
289
  else:
290
  choice = f'Incorrect. {response}'
291
 
292
  choices.append(choice)
 
 
293
 
294
  return {
295
  reveal_btn: gr.Button(visible=False),