File size: 4,914 Bytes
6a20eb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import hashlib
from typing import List, Tuple

import numpy as np
import torch
from scipy.stats import norm
from transformers import LogitsWarper


class GPTWatermarkBase:
    """
    Base class for watermarking distributions with fixed-group green-listed tokens.

    A boolean mask with ``int(fraction * vocab_size)`` green entries is shuffled
    by a NumPy generator seeded from a hash of ``watermark_key`` and stored in
    ``green_list_mask`` as a float32 tensor (1.0 = green, 0.0 = not green).

    Args:
        fraction: The fraction of the distribution to be green-listed; must lie in [0, 1].
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary; must not be negative.
        watermark_key: The random seed for the green-listing.

    Raises:
        ValueError: If ``fraction`` is outside [0, 1] or ``vocab_size`` is negative.
    """

    def __init__(self, fraction: float = 0.5, strength: float = 2.0, vocab_size: int = 50257, watermark_key: int = 0):
        # Reject values that would otherwise silently produce a mask whose
        # length differs from the vocabulary size.
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(f"fraction must be in [0, 1], got {fraction!r}")
        if vocab_size < 0:
            raise ValueError(f"vocab_size must be non-negative, got {vocab_size!r}")

        num_green = int(fraction * vocab_size)
        # Green entries first, then shuffled: same starting layout as before,
        # so a given key keeps producing the same green list.
        mask = np.zeros(vocab_size, dtype=bool)
        mask[:num_green] = True
        rng = np.random.default_rng(self._hash_fn(watermark_key))
        rng.shuffle(mask)

        self.green_list_mask = torch.tensor(mask, dtype=torch.float32)
        self.strength = strength
        self.fraction = fraction
        self.vocab_size = vocab_size

    @staticmethod
    def _hash_fn(x: int) -> int:
        """
        Map ``x`` to an unsigned 32-bit integer derived from its SHA-256 digest.

        solution from https://stackoverflow.com/questions/67219691/python-hash-function-that-returns-32-or-64-bits
        """
        # Serialize the 64-bit value explicitly as little-endian so the derived
        # seed does not depend on the host's native byte order.
        payload = int(np.int64(x)).to_bytes(8, "little", signed=True)
        return int.from_bytes(hashlib.sha256(payload).digest()[:4], 'little')


class GPTWatermarkLogitsWarper(GPTWatermarkBase, LogitsWarper):
    """
    LogitsWarper for watermarking distributions with fixed-group green-listed tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, *args, **kwargs):
        # All configuration is forwarded unchanged to the parent initializers.
        super().__init__(*args, **kwargs)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
        """
        Return ``scores`` with ``strength`` added at every green-listed position.

        ``input_ids`` is accepted but not used: the green list is fixed and
        does not depend on the tokens generated so far.
        """
        # Scale the 0/1 mask first, then move the bias to wherever the scores live.
        green_bias = (self.strength * self.green_list_mask).to(scores.device)
        return scores + green_bias


class GPTWatermarkDetector(GPTWatermarkBase):
    """
    Class for detecting watermarks in a sequence of tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _z_score(num_green: int, total: int, fraction: float) -> float:
        """
        Calculate and return the z-score of the number of green tokens in a sequence.

        Computes ``(num_green - fraction * total) / sqrt(fraction * (1 - fraction) * total)``.
        The denominator is zero when ``total`` is 0 or ``fraction`` is 0 or 1,
        so callers should pass a non-empty sequence and ``0 < fraction < 1``.
        """
        expected_green = fraction * total
        spread = np.sqrt(fraction * (1 - fraction) * total)
        return (num_green - expected_green) / spread

    @staticmethod
    def _compute_tau(m: int, N: int, alpha: float) -> float:
        """
        Compute the threshold tau for the dynamic thresholding.

        Args:
            m: The number of unique tokens in the sequence.
            N: Vocabulary size; must be greater than 1 (``N - 1`` is a divisor).
            alpha: The false positive rate to control.

        Returns:
            The threshold tau, ``sqrt(1 - (m - 1) / (N - 1)) * norm.ppf(1 - alpha)``.
        """
        factor = np.sqrt(1 - (m - 1) / (N - 1))
        tau = factor * norm.ppf(1 - alpha)
        return tau

    def _green_flags(self, sequence: List[int]) -> List[bool]:
        """Return, for each token id in ``sequence``, whether it is green-listed."""
        # One vectorized lookup instead of indexing the mask tensor per token.
        token_ids = torch.as_tensor(sequence, dtype=torch.long, device=self.green_list_mask.device)
        return self.green_list_mask[token_ids].bool().tolist()

    def detect(self, sequence: List[int]) -> Tuple[float, List[bool], int, int]:
        """
        Detect the watermark in a sequence of tokens.

        Args:
            sequence: Token ids to score; repeated tokens are counted every time they occur.

        Returns:
            A 4-tuple ``(z_score, green_tokens_mask, green_tokens, total)``:
            the z-score of the green-token count, a per-token list of booleans
            marking green tokens, the number of green tokens, and ``len(sequence)``.
        """
        green_tokens_mask = self._green_flags(sequence)
        green_tokens = sum(green_tokens_mask)
        total = len(sequence)
        return self._z_score(green_tokens, total, self.fraction), green_tokens_mask, green_tokens, total

    def unidetect(self, sequence: List[int]) -> float:
        """Detect the watermark in a sequence of tokens and return the z value. Just for unique tokens."""
        unique_tokens = list(set(sequence))
        green_tokens = sum(self._green_flags(unique_tokens))
        return self._z_score(green_tokens, len(unique_tokens), self.fraction)

    def dynamic_threshold(self, sequence: List[int], alpha: float, vocab_size: int) -> Tuple[bool, float]:
        """
        Dynamic thresholding for watermark detection.

        Args:
            sequence: Token ids to test.
            alpha: The false positive rate to control.
            vocab_size: Vocabulary size used to compute the threshold.

        Returns:
            ``(is_watermarked, z_score)``: True if the unique-token z-score
            exceeds the threshold tau, False otherwise, together with that z-score.
        """
        z_score = self.unidetect(sequence)
        tau = self._compute_tau(len(set(sequence)), vocab_size, alpha)
        # Convert the NumPy comparison result to a plain bool, matching the annotation.
        return bool(z_score > tau), z_score