"""Gradio demo Space for watermarking text with SynthID Text.

Loads a Gemma model together with a SynthID Text logits processor and a
Bayesian detector, then serves a small game: the user enters prompts, the
model produces one watermarked and one non-watermarked response per prompt,
and the user guesses which responses carry the watermark.
"""

from collections.abc import Sequence
import json
import random
from typing import Optional

import gradio as gr
import spaces
import torch
import transformers

# If the watermark is not detected, consider the use case. It could be because
# of the nature of the task (e.g., factual responses are lower entropy) or it
# could have another cause.

_MODEL_IDENTIFIER = 'google/gemma-2b'
_DETECTOR_IDENTIFIER = 'gg-hf/detector_2b_1.0_demo'

# Default text shown in the prompt boxes; one Textbox is created per entry.
_PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
)

# Maps each response of the most recent generation to whether it was produced
# with watermarking enabled. Module-level, hence shared by every session.
_CORRECT_ANSWERS: dict[str, bool] = {}

_TORCH_DEVICE = (
    torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
)

# Demo-only watermarking configuration. It is rendered in the page below, so
# it must never be reused for production purposes.
_WATERMARK_CONFIG_DICT = dict(
    ngram_len=5,
    keys=[
        654, 400, 836, 123, 340, 443, 597, 160, 57, 29,
        590, 639, 13, 715, 468, 990, 966, 226, 324, 585,
        118, 504, 421, 521, 129, 669, 732, 225, 90, 960,
    ],
    sampling_table_size=2**16,
    sampling_table_seed=0,
    context_history_size=1024,
)

_WATERMARK_CONFIG = transformers.generation.SynthIDTextWatermarkingConfig(
    **_WATERMARK_CONFIG_DICT
)

# Decoder-only models continue from the last token of each row, so batched
# prompts have to be padded on the left.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    _MODEL_IDENTIFIER,
    padding_side='left',
)
tokenizer.pad_token_id = tokenizer.eos_token_id

model = transformers.AutoModelForCausalLM.from_pretrained(_MODEL_IDENTIFIER)
model.to(_TORCH_DEVICE)

logits_processor = transformers.generation.SynthIDTextWatermarkLogitsProcessor(
    **_WATERMARK_CONFIG_DICT,
    device=_TORCH_DEVICE,
)

detector_module = transformers.generation.BayesianDetectorModel.from_pretrained(
    _DETECTOR_IDENTIFIER,
)
detector_module.to(_TORCH_DEVICE)

# The detector needs the tokenizer to mask out EOS tokens when scoring.
detector = transformers.generation.watermarking.SynthIDTextWatermarkDetector(
    detector_module=detector_module,
    logits_processor=logits_processor,
    tokenizer=tokenizer,
)


@spaces.GPU
def generate_outputs(
    prompts: Sequence[str],
    watermarking_config: Optional[
        transformers.generation.SynthIDTextWatermarkingConfig
    ] = None,
) -> Sequence[str]:
    """Generate one sampled completion per prompt.

    Args:
        prompts: Non-empty batch of prompt strings.
        watermarking_config: When given, it is forwarded to `.generate()` so
            the output is watermarked; when `None`, generation is unmodified.

    Returns:
        The decoded sequences (prompt plus continuation), one per prompt,
        with special tokens such as padding removed.
    """
    # `padding=True` is required to batch prompts of different token lengths
    # into a single tensor.
    tokenized_prompts = tokenizer(
        list(prompts),
        return_tensors='pt',
        padding=True,
    ).to(_TORCH_DEVICE)
    output_sequences = model.generate(
        **tokenized_prompts,
        watermarking_config=watermarking_config,
        do_sample=True,
        max_length=500,
        top_k=40,
    )
    # Detector scores are only printed for inspection; they do not affect the
    # returned text.
    detections = detector(output_sequences)
    print(detections)
    return tokenizer.batch_decode(output_sequences, skip_special_tokens=True)


with gr.Blocks() as demo:
    gr.Markdown(
        f'''
        # Using SynthID Text in your Generative AI projects

        [SynthID][synthid] is a Google DeepMind technology that watermarks and
        identifies AI-generated content by embedding digital watermarks
        directly into AI-generated images, audio, text or video.

        SynthID Text is an open source implementation of this technology
        available in Hugging Face Transformers that has two major components:

        * A [logits processor][synthid-hf-logits-processor] that is
          [configured][synthid-hf-config] on a per-model basis and activated
          when calling `.generate()`; and
        * A [detector][synthid-hf-detector] trained to recognize watermarked
          text generated by a specific model with a specific configuration.

        This Space demonstrates:

        1. How to use SynthID Text to apply a watermark to text generated by
           your model; and
        1. How to identify that text using a ready-made detector.

        Note that this detector is trained specifically for this
        demonstration. You should maintain a specific watermarking
        configuration for every model you use and protect that configuration
        as you would any other secret. See the
        [end-to-end guide][synthid-hf-detector-e2e] for more on training your
        own detectors, and the [SynthID Text documentation][raitk-synthid] for
        more on how this technology works.

        ## Getting started

        Practically speaking, SynthID Text is a logits processor, applied to
        your model's generation pipeline after
        [Top-K and Top-P][cloud-parameter-values], that augments the model's
        logits using a pseudorandom _g_-function to encode watermarking
        information in a way that balances generation quality with watermark
        detectability. See the [paper][synthid-nature] for a complete
        technical description of the algorithm and analyses of how different
        configuration values affect performance.

        Watermarks are [configured][synthid-hf-config] to parameterize the
        _g_-function and how it is applied during generation. We use the
        following configuration for all demos. It should not be used for any
        production purposes.

        ```json
        {json.dumps(_WATERMARK_CONFIG_DICT)}
        ```

        Watermarks are applied by initializing a
        `SynthIDTextWatermarkingConfig` and passing that as the
        `watermarking_config=` parameter in your call to `.generate()`, as
        shown in the snippet below.

        ```python
        from transformers import AutoModelForCausalLM, AutoTokenizer
        from transformers.generation import SynthIDTextWatermarkingConfig

        # Standard model and tokenizer initialization
        tokenizer = AutoTokenizer.from_pretrained('repo/id')
        model = AutoModelForCausalLM.from_pretrained('repo/id')

        # SynthID Text configuration
        watermarking_config = SynthIDTextWatermarkingConfig(...)

        # Generation with watermarking
        tokenized_prompts = tokenizer(
            ["your prompts here"], return_tensors="pt", padding=True
        )
        output_sequences = model.generate(
            **tokenized_prompts,
            watermarking_config=watermarking_config,
            do_sample=True,
        )
        watermarked_text = tokenizer.batch_decode(output_sequences)
        ```

        Enter up to three prompts then click the generate button. After you
        click, [Gemma 2B][gemma] will generate watermarked and
        non-watermarked responses for each non-empty prompt.

        [cloud-parameter-values]: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/prompts/adjust-parameter-values
        [gemma]: https://huggingface.co/google/gemma-2b
        [raitk-synthid]: https://ai.google.dev/responsible/docs/safeguards/synthid
        [synthid]: https://deepmind.google/technologies/synthid/
        [synthid-hf-config]: https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/generation/configuration_utils.py
        [synthid-hf-detector]: https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/generation/watermarking.py
        [synthid-hf-detector-e2e]: https://github.com/huggingface/transformers/blob/v4.46.0/examples/research_projects/synthid_text/detector_bayesian.py
        [synthid-hf-logits-processor]: https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/generation/logits_process.py
        [synthid-nature]: https://www.nature.com/articles/s41586-024-08025-4
        '''
    )

    prompt_inputs = [
        gr.Textbox(value=prompt, lines=4, label='Prompt')
        for prompt in _PROMPTS
    ]
    generate_btn = gr.Button('Generate')

    # Step 2: shown once responses exist; the user picks the watermarked ones.
    with gr.Column(visible=False) as generations_col:
        gr.Markdown(
            '''
            # SynthID: Tool
            '''
        )
        generations_grp = gr.CheckboxGroup(
            label='All generations, in random order',
            info='Select the generations you think are watermarked!',
        )
        reveal_btn = gr.Button('Reveal', visible=False)

    # Step 3: shown after "Reveal"; displays the ground truth.
    with gr.Column(visible=False) as detections_col:
        gr.Markdown(
            '''
            # SynthID: Tool
            '''
        )
        revealed_grp = gr.CheckboxGroup(
            label='Ground truth for all generations',
            info=(
                'Watermarked generations are checked, and your selections are '
                'marked as correct or incorrect in the text.'
            ),
        )
        # No click handler is attached to this button in this file.
        detect_btn = gr.Button('Detect', visible=False)

    def generate(*prompts):
        """Generate shuffled watermarked and plain responses for the prompts.

        Blank prompts are skipped. The ground truth for the new responses
        replaces whatever `_CORRECT_ANSWERS` held before.

        Raises:
            gr.Error: If every prompt is empty.
        """
        active_prompts = [p for p in prompts if p and p.strip()]
        if not active_prompts:
            raise gr.Error('Enter at least one prompt.')

        standard = list(generate_outputs(prompts=active_prompts))
        watermarked = list(
            generate_outputs(
                prompts=active_prompts,
                watermarking_config=_WATERMARK_CONFIG,
            )
        )
        responses = standard + watermarked
        random.shuffle(responses)

        # Drop answers from any earlier run so `reveal` only grades the
        # responses that are currently on screen.
        watermarked_lookup = set(watermarked)
        _CORRECT_ANSWERS.clear()
        _CORRECT_ANSWERS.update({
            response: response in watermarked_lookup
            for response in responses
        })

        return {
            generate_btn: gr.Button(visible=False),
            generations_col: gr.Column(visible=True),
            generations_grp: gr.CheckboxGroup(choices=responses),
            reveal_btn: gr.Button(visible=True),
        }

    generate_btn.click(
        generate,
        inputs=prompt_inputs,
        outputs=[generate_btn, generations_col, generations_grp, reveal_btn],
    )

    def reveal(user_selections: list[str]):
        """Grade the user's picks against `_CORRECT_ANSWERS`.

        Each response is prefixed with 'Correct!' when the user's choice
        (selected or not) matches whether it was watermarked, and with
        'Incorrect.' otherwise. Watermarked responses are pre-checked in the
        resulting group.
        """
        selected = set(user_selections or [])
        choices: list[str] = []
        value: list[str] = []
        for response, is_watermarked in _CORRECT_ANSWERS.items():
            guessed_right = is_watermarked == (response in selected)
            verdict = 'Correct!' if guessed_right else 'Incorrect.'
            choice = f'{verdict} {response}'
            choices.append(choice)
            if is_watermarked:
                value.append(choice)

        return {
            reveal_btn: gr.Button(visible=False),
            detections_col: gr.Column(visible=True),
            revealed_grp: gr.CheckboxGroup(choices=choices, value=value),
            detect_btn: gr.Button(visible=True),
        }

    reveal_btn.click(
        reveal,
        inputs=generations_grp,
        outputs=[reveal_btn, detections_col, revealed_grp, detect_btn],
    )

if __name__ == '__main__':
    demo.launch()