"""Gradio demo: guess which generations are watermarked, then reveal the truth.

The responses produced here are placeholders built from the prompt text; no
model is loaded yet (see the planning notes below).
"""

from collections.abc import Sequence
import random

import gradio as gr
import immutabledict
# NOTE(review): `spaces` is not referenced in this file; presumably the import
# is needed by the hosting environment -- confirm before removing.
import spaces
import torch

#### Version 1: Baseline
# Step 1: Select and load your model
# Step 2: Load the test dataset (4-5 examples)
# Step 3: Run generation with and without watermarking, display the outputs
# Step 4: User clicks the reveal button to see the watermarked vs not gens

#### Version 2: Gamification
# Steps 1-3 the same
# Step 4: User marks specific generations as watermarked
# Step 5: User clicks the reveal button to see the watermarked vs not gens

# If the watermark is not detected, consider the use case. Could be because of
# the nature of the task (e.g., factual responses are lower entropy) or it could
# be another cause.

GEMMA_2B = 'google/gemma-2b'

# Default prompt texts, one editable textbox is created per entry.
PROMPTS: tuple[str, ...] = (
    'prompt 1',
    'prompt 2',
    'prompt 3',
    'prompt 4',
)

# Watermarking parameters. Not consumed anywhere in this file yet.
WATERMARKING_CONFIG = immutabledict.immutabledict({
    "ngram_len": 5,
    "keys": [
        654, 400, 836, 123, 340, 443, 597, 160, 57, 29,
        590, 639, 13, 715, 468, 990, 966, 226, 324, 585,
        118, 504, 421, 521, 129, 669, 732, 225, 90, 960,
    ],
    "sampling_table_size": 2**16,
    "sampling_table_seed": 0,
    "context_history_size": 1024,
    "device": (
        torch.device("cuda:0")
        if torch.cuda.is_available()
        else torch.device("cpu")
    ),
})

# Maps each displayed response to whether it is the watermarked variant.
# NOTE(review): this is module-level state, so it is shared by every caller of
# `generate`/`reveal` in the process; per-session state may be preferable.
_CORRECT_ANSWERS: dict[str, bool] = {}


def _grade_selections(
    answers: dict[str, bool],
    user_selections: Sequence[str] = (),
) -> tuple[list[str], list[str]]:
    """Label every response as a correct or incorrect guess.

    Args:
        answers: Mapping of response text to whether it is watermarked.
        user_selections: Responses the user marked as watermarked. ``None``
            is treated as "nothing selected".

    Returns:
        A ``(choices, value)`` pair: ``choices`` holds one labelled string per
        response, in the iteration order of ``answers``; ``value`` holds the
        labelled strings of the watermarked responses only.
    """
    selected = set(user_selections or ())
    choices: list[str] = []
    value: list[str] = []
    for response, is_watermarked in answers.items():
        # A guess is right when "selected" agrees with "is watermarked".
        guessed_right = is_watermarked == (response in selected)
        prefix = 'Correct!' if guessed_right else 'Incorrect.'
        choice = f'{prefix} {response}'
        choices.append(choice)
        if is_watermarked:
            value.append(choice)
    return choices, value


with gr.Blocks() as demo:
    prompt_inputs = [
        gr.Textbox(value=prompt, lines=4, label='Prompt')
        for prompt in PROMPTS
    ]
    generate_btn = gr.Button('Generate')

    # Hidden until `generate` has produced responses.
    with gr.Column(visible=False) as generations_col:
        generations_grp = gr.CheckboxGroup(
            label='All generations, in random order',
            info='Select the generations you think are watermarked!',
        )
        reveal_btn = gr.Button('Reveal', visible=False)

    # Hidden until `reveal` has graded the user's selection.
    with gr.Column(visible=False) as detections_col:
        revealed_grp = gr.CheckboxGroup(
            label='Ground truth for all generations',
            info=(
                'Watermarked generations are checked, and your selection are '
                'marked as correct or incorrect in the text.'
            ),
        )
        detect_btn = gr.Button('Detect', visible=False)

    def generate(*prompts) -> dict:
        """Build shuffled placeholder responses and record the ground truth.

        Args:
            *prompts: Current text of each prompt textbox.

        Returns:
            A mapping from output component to its update: hides the generate
            button, shows the generations column and reveal button, and sets
            the shuffled responses as the checkbox choices.
        """
        standard = [f'{prompt} response' for prompt in prompts]
        watermarked = [f'{prompt} watermarked response' for prompt in prompts]
        responses = standard + watermarked
        random.shuffle(responses)

        watermarked_set = set(watermarked)
        # Drop answers from earlier runs so `reveal` only grades the responses
        # currently displayed, in the order they were just shuffled into.
        _CORRECT_ANSWERS.clear()
        _CORRECT_ANSWERS.update({
            response: response in watermarked_set for response in responses
        })

        # Load model
        return {
            generate_btn: gr.Button(visible=False),
            generations_col: gr.Column(visible=True),
            generations_grp: gr.CheckboxGroup(
                responses,
            ),
            reveal_btn: gr.Button(visible=True),
        }

    generate_btn.click(
        generate,
        inputs=prompt_inputs,
        outputs=[generate_btn, generations_col, generations_grp, reveal_btn],
    )

    def reveal(user_selections: list[str]) -> dict:
        """Grade the user's guesses and show which responses are watermarked.

        Args:
            user_selections: Responses the user ticked as watermarked.

        Returns:
            A mapping from output component to its update: hides the reveal
            button, shows the detections column and detect button, and fills
            the ground-truth group with labelled responses, checking the
            watermarked ones.
        """
        choices, value = _grade_selections(_CORRECT_ANSWERS, user_selections)
        return {
            reveal_btn: gr.Button(visible=False),
            detections_col: gr.Column(visible=True),
            revealed_grp: gr.CheckboxGroup(choices=choices, value=value),
            detect_btn: gr.Button(visible=True),
        }

    reveal_btn.click(
        reveal,
        inputs=generations_grp,
        outputs=[
            reveal_btn,
            detections_col,
            revealed_grp,
            detect_btn,
        ],
    )

if __name__ == '__main__':
    demo.launch()