# Unigram-Watermark / gptwm.py
import hashlib
from typing import List, Tuple
import numpy as np
from scipy.stats import norm
import torch
from transformers import LogitsWarper


class GPTWatermarkBase:
    """
    Base class for watermarking distributions with fixed-group green-listed tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, fraction: float = 0.5, strength: float = 2.0, vocab_size: int = 50257, watermark_key: int = 0):
        # Seed the RNG with a hash of the watermark key so the green list is reproducible.
        rng = np.random.default_rng(self._hash_fn(watermark_key))
        # Mark a `fraction` of the vocabulary as green, then shuffle so the green
        # list is a fixed random subset of token ids determined only by the key.
        mask = np.array([True] * int(fraction * vocab_size) + [False] * (vocab_size - int(fraction * vocab_size)))
        rng.shuffle(mask)
        self.green_list_mask = torch.tensor(mask, dtype=torch.float32)
        self.strength = strength
        self.fraction = fraction

    @staticmethod
    def _hash_fn(x: int) -> int:
        """Deterministic 32-bit hash of an integer key; adapted from
        https://stackoverflow.com/questions/67219691/python-hash-function-that-returns-32-or-64-bits"""
        x = np.int64(x)
        return int.from_bytes(hashlib.sha256(x).digest()[:4], 'little')
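
# Note: the green list is a pure function of `watermark_key` (plus fraction and
# vocab_size), so a warper and a detector built with the same key share the same
# mask. Quick sanity check (illustrative sketch, not part of the original file):
#
#   assert torch.equal(GPTWatermarkBase(watermark_key=42).green_list_mask,
#                      GPTWatermarkBase(watermark_key=42).green_list_mask)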


class GPTWatermarkLogitsWarper(GPTWatermarkBase, LogitsWarper):
    """
    LogitsWarper for watermarking distributions with fixed-group green-listed tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
        """Add the watermark to the logits and return new logits."""
        # Boost every green-listed token's logit by the constant `strength`;
        # red-listed tokens are left unchanged.
        watermark = self.strength * self.green_list_mask
        new_logits = scores + watermark.to(scores.device)
        return new_logits
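
# Example usage (sketch): attaching the warper to Hugging Face text generation.
# The model name, prompt, and sampling settings below are illustrative
# assumptions, not part of this file.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   model = AutoModelForCausalLM.from_pretrained("gpt2")
#   warper = GPTWatermarkLogitsWarper(fraction=0.5, strength=2.0, vocab_size=len(tokenizer))
#   inputs = tokenizer("The quick brown fox", return_tensors="pt")
#   out = model.generate(**inputs, max_new_tokens=64, do_sample=True,
#                        logits_processor=LogitsProcessorList([warper]))
#   print(tokenizer.decode(out[0], skip_special_tokens=True))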


class GPTWatermarkDetector(GPTWatermarkBase):
    """
    Class for detecting watermarks in a sequence of tokens.

    Args:
        fraction: The fraction of the distribution to be green-listed.
        strength: The strength of the green-listing. Higher values result in higher logit scores for green-listed tokens.
        vocab_size: The size of the vocabulary.
        watermark_key: The random seed for the green-listing.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @staticmethod
    def _z_score(num_green: int, total: int, fraction: float) -> float:
        """Calculate and return the z-score of the number of green tokens in a sequence."""
        # One-proportion z-test against the null hypothesis that tokens land in
        # the green list at rate `fraction` by chance.
        return (num_green - fraction * total) / np.sqrt(fraction * (1 - fraction) * total)

    @staticmethod
    def _compute_tau(m: int, N: int, alpha: float) -> float:
        """
        Compute the threshold tau for the dynamic thresholding.

        Args:
            m: The number of unique tokens in the sequence.
            N: Vocabulary size.
            alpha: The false positive rate to control.
        Returns:
            The threshold tau.
        """
        # Note: 1 - (m - 1) / (N - 1) == (N - m) / (N - 1), the finite-population
        # correction for drawing m distinct tokens from a vocabulary of N, so tau
        # scales the Gaussian quantile norm.ppf(1 - alpha) accordingly.
        factor = np.sqrt(1 - (m - 1) / (N - 1))
        tau = factor * norm.ppf(1 - alpha)
        return tau
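
    # Worked example for _compute_tau (illustrative numbers): m=100 unique tokens,
    # N=50257, alpha=0.01 gives factor = sqrt(1 - 99/50256) ~= 0.999 and
    # norm.ppf(0.99) ~= 2.326, so tau ~= 2.324; a unique-token z-score above that
    # threshold flags the sequence as watermarked at false positive rate 0.01.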

    def detect(self, sequence: List[int]) -> Tuple[float, List[bool], int, int]:
        """Detect the watermark in a sequence of tokens. Returns the z-score, the
        per-token green-list mask, the green-token count, and the sequence length."""
        # A token contributes to the count iff its id is on the green list.
        green_tokens_mask = [bool(self.green_list_mask[i]) for i in sequence]
        green_tokens = sum(green_tokens_mask)
        return self._z_score(green_tokens, len(sequence), self.fraction), green_tokens_mask, green_tokens, len(sequence)

    def unidetect(self, sequence: List[int]) -> float:
        """Detect the watermark using only the unique tokens in the sequence and return the z-score."""
        sequence = list(set(sequence))
        green_tokens = int(sum(self.green_list_mask[i] for i in sequence))
        return self._z_score(green_tokens, len(sequence), self.fraction)

    def dynamic_threshold(self, sequence: List[int], alpha: float, vocab_size: int) -> Tuple[bool, float]:
        """Dynamic thresholding for watermark detection. Returns (is_watermarked, z_score)."""
        z_score = self.unidetect(sequence)
        tau = self._compute_tau(len(set(sequence)), vocab_size, alpha)
        return z_score > tau, z_score
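

if __name__ == "__main__":
    # Minimal end-to-end sketch on synthetic token ids (no model required).
    # All numbers below are illustrative assumptions, not values from the repo.
    detector = GPTWatermarkDetector(fraction=0.5, strength=2.0, vocab_size=50257)
    green_ids = torch.nonzero(detector.green_list_mask > 0).squeeze(-1).tolist()
    red_ids = torch.nonzero(detector.green_list_mask == 0).squeeze(-1).tolist()
    rng = np.random.default_rng(0)
    # 80% green tokens mimics watermarked text; unwatermarked text sits near 50%.
    sequence = rng.choice(green_ids, size=80, replace=False).tolist() + \
               rng.choice(red_ids, size=20, replace=False).tolist()
    z, _, num_green, total = detector.detect(sequence)
    is_watermarked, z_unique = detector.dynamic_threshold(sequence, alpha=0.01, vocab_size=50257)
    print(f"z = {z:.2f}, green = {num_green}/{total}, watermarked @ alpha=0.01: {is_watermarked}")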