Llama2_watermarking / watermark.py
Antoine Chaffin
Moving the auth token to where the model is loaded
4491e36
raw
history blame
No virus
12.1 kB
import transformers
from transformers import AutoTokenizer
from transformers import pipeline, set_seed, LogitsProcessor
from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper
import torch
from scipy.special import gamma, gammainc, gammaincc, betainc
from scipy.optimize import fminbound
import numpy as np
# Default device for the pseudo-random generators and detection buffers:
# the first CUDA GPU when one is visible, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def hash_tokens(input_ids: torch.LongTensor, key: int):
    """Fold a sequence of token ids into a PRNG seed.

    A polynomial rolling hash: starting from ``key``, each token id is
    combined as ``state * 35317 + token``, reduced modulo ``2**64 - 1`` so the
    result is a non-negative integer accepted by ``torch.Generator.manual_seed``.

    Args:
        input_ids: 1-D tensor of token ids (the context window).
        key: integer watermark key used as the initial state.

    Returns:
        The resulting seed as a Python int (``key`` itself for an empty window).
    """
    multiplier = 35317
    modulus = 2 ** 64 - 1
    state = key
    for token_id in input_ids.tolist():
        state = (state * multiplier + token_id) % modulus
    return state
class WatermarkingLogitsProcessor(LogitsProcessor):
    """State shared by the watermarking logits processors.

    Keeps one ``torch.Generator`` per message (one per batch row) together with
    the watermark parameters that subclasses read in ``__call__``.

    Args:
        n: length of the pseudo-random vector subclasses draw at each step.
        key: integer watermark key.
        messages: one integer message per batch row.
        window_size: number of trailing tokens hashed to reseed each generator;
            a falsy value means the generators are seeded once, here, with ``key``.
    """

    def __init__(self, n, key, messages, window_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = len(messages)
        self.generators = [torch.Generator(device=device) for _ in range(self.batch_size)]
        self.n, self.key, self.window_size = n, key, window_size
        if not window_size:
            # Without a context window the random stream depends on the key alone:
            # seed once now and let the subclasses' successive draws advance it.
            for gen in self.generators:
                gen.manual_seed(key)
        self.messages = messages
class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):
    """Aaronson-style watermark, intended for greedy decoding.

    For every batch row a uniform vector R of length ``self.n`` is drawn from the
    row's generator, cyclically shifted by the row's message, truncated to the
    vocabulary width, and returned as ``log(R) / exp(scores)``. The argmax of this
    is the token maximizing ``R ** (1 / p)``.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        B, V = scores.shape
        # Allocate n columns (not V): exactly n values are drawn per row so the
        # generator stream stays in sync with detection, where Watermarker.detect
        # draws self.V values per token and Watermarker.embed passes n=self.V.
        # (Previously this buffer had only V columns, so the row assignment below
        # failed whenever n > V.)
        r = torch.zeros(B, self.n, dtype=scores.dtype, device=scores.device)
        for b in range(B):
            if self.window_size:
                # Reseed from a hash of the last window_size tokens of this row.
                window = input_ids[b, -self.window_size:]
                seed = hash_tokens(window, self.key)
                self.generators[b].manual_seed(seed)
            r[b] = torch.rand(self.n, generator=self.generators[b],
                              device=self.generators[b].device).log().roll(-self.messages[b])
        # Keep only the columns that correspond to vocabulary entries.
        r = r[:, :V]
        # log(R ** (1/p)) = log(R) / p. exp(scores) equals p up to a positive
        # per-row factor, which leaves the argmax unchanged.
        # NOTE(review): exp() overflows to inf for very large logits (float32 > ~88),
        # turning log(R) / p into -0 for those tokens — confirm logits stay in range.
        return r / scores.exp()
class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
    """Kirchenbauer-style (greenlist) watermark, intended for sampling.

    At each step a pseudo-random permutation of ``range(self.n)`` is drawn per
    batch row; its first ``int(gamma * n)`` entries form the greenlist. Logits of
    greenlisted tokens (after shifting by the row's message) receive ``+delta``.
    ``scores`` is modified in place and returned.

    Args:
        gamma: fraction of the n slots placed on the greenlist.
        delta: bias added to greenlisted logits.
    """

    def __init__(self, *args, gamma=0.5, delta=4.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.gamma = gamma
        self.delta = delta

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        rows, vocab_width = scores.shape
        green_size = int(self.gamma * self.n)
        for row in range(rows):
            gen = self.generators[row]
            if self.window_size:
                # Reseed from a hash of the last window_size tokens of this row.
                gen.manual_seed(hash_tokens(input_ids[row, -self.window_size:], self.key))
            permutation = torch.randperm(self.n, generator=gen, device=gen.device)
            offsets = torch.zeros(self.n, device=scores.device)
            offsets[permutation[:green_size]] = self.delta
            # Shift by the message so token j is green iff (j + message) mod n was drawn.
            scores[row] += offsets.roll(-self.messages[row])[:vocab_width]
        return scores
class Watermarker(object):
    """Embed and detect multi-bit watermarks in text generated by a causal LM.

    Messages are integers in ``[0, 2**payload_bits)``. Embedding shifts the
    key-dependent pseudo-random vectors by the message; detection scores every
    candidate shift and reports the best one together with a p-value.
    """

    def __init__(self, tokenizer=None, model=None, window_size=0, payload_bits=0, logits_processor=None, *args, **kwargs):
        """Wrap a tokenizer/model pair.

        Args:
            tokenizer: tokenizer used to encode prompts/texts and decode generations.
            model: language model exposing ``eval``, ``generate``, ``forward``,
                ``config.vocab_size`` and ``device``. Required despite the ``None``
                default, since ``model.eval()`` is called unconditionally.
            window_size: number of preceding tokens hashed to reseed the PRNG at
                each step; 0 seeds it once with the key.
            payload_bits: number of message bits.
            logits_processor: processors applied before the watermarking processor.
        """
        self.tokenizer = tokenizer
        self.model = model
        self.model.eval()
        self.window_size = window_size
        # preprocessing wrappers
        self.logits_processor = logits_processor or []
        self.payload_bits = payload_bits
        # Pseudo-random vectors must cover both every vocabulary entry and every
        # message shift.
        self.V = max(2**payload_bits, self.model.config.vocab_size)
        self.generator = torch.Generator(device=device)

    def embed(self, key=42, messages=None, prompt="", max_length=30, method='aaronson'):
        """Generate one continuation of ``prompt`` per message.

        Args:
            key: watermark key.
            messages: list of integer messages, one per generated sequence;
                defaults to ``[1234]``. NOTE: detection only searches
                ``[0, 2**payload_bits)``, so the default message is recoverable only
                when ``payload_bits >= 11``.
            prompt: prompt shared by all sequences.
            max_length: forwarded to ``model.generate``.
            method: 'aaronson' (watermarked, greedy), 'kirchenbauer' (watermarked,
                sampling), 'greedy' or 'sampling' (no watermark).

        Returns:
            The decoded texts, special tokens removed.

        Raises:
            Exception: if ``method`` is unknown.
            AssertionError: if a message is out of range (only when payload_bits > 0).
        """
        if messages is None:
            messages = [1234]
        B = len(messages)  # batch size
        if self.payload_bits:
            assert all(0 <= message < 2**self.payload_bits for message in messages)
        inputs = self.tokenizer([prompt] * B, return_tensors="pt")
        input_ids = inputs.input_ids.to(device)
        if method == 'aaronson':
            # watermarked, greedy search
            do_sample = False
            processors = self.logits_processor + [
                WatermarkingAaronsonLogitsProcessor(n=self.V, key=key, messages=messages,
                                                    window_size=self.window_size)]
        elif method == 'kirchenbauer':
            # watermarked, sampling
            do_sample = True
            processors = self.logits_processor + [
                WatermarkingKirchenbauerLogitsProcessor(n=self.V, key=key, messages=messages,
                                                        window_size=self.window_size)]
        elif method == 'greedy':
            do_sample = False
            processors = self.logits_processor
        elif method == 'sampling':
            do_sample = True
            processors = self.logits_processor
        else:
            raise Exception('Unknown method %s' % method)
        generated_ids = self.model.generate(input_ids, max_length=max_length, do_sample=do_sample,
                                            logits_processor=processors)
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True,
                                           clean_up_tokenization_spaces=False)

    def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
        """Test texts for the watermark and recover the most likely message.

        Args:
            attacked_texts: texts to test (possibly edited after generation).
            key: watermark key used at embedding time.
            method: 'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'
                or 'kirchenbauer'.
            gamma: greenlist fraction for 'kirchenbauer' (must match embedding).
            prompts: one prompt per text; their leading tokens are not scored.
                Defaults to empty prompts.

        Returns:
            ``(cdfs, ms)``: per text, the p-value corrected for the
            ``2**payload_bits`` candidate messages, and the decoded message.

        Raises:
            Exception: if ``method`` is unknown.
        """
        if method not in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson', 'kirchenbauer'}:
            raise Exception('Unknown method %s' % method)
        if prompts is None:
            prompts = [""] * len(attacked_texts)
        generator = self.generator
        cdfs = []
        ms = []
        MAX = 2**self.payload_bits
        inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_masks = inputs["attention_mask"].to(self.model.device)
        B, T = input_ids.shape
        if method == 'aaronson_neyman_pearson':
            # Probabilities are only read; no autograd graph is needed.
            with torch.no_grad():
                outputs = self.model.forward(input_ids, return_dict=True)
            logits = outputs['logits']
            # TODO: reapply self.logits_processor so detection sees the same
            # distribution that was used at generation time.
            probs = logits.softmax(dim=-1)
            # ps[b, i] = model probability of token i+1 given tokens 0..i
            ps = torch.gather(probs, 2, input_ids[:, 1:, None]).squeeze_(-1)
        V = self.V
        Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)
        # Contexts already seen per text: repeated contexts are excluded from the
        # score so the p-value under H0 stays valid.
        history = [set() for _ in range(B)]
        attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True,
                                                 return_attention_mask=True)["attention_mask"]
        prompts_length = torch.sum(attention_masks_prompts, dim=1)
        for b in range(B):
            # NOTE(review): assumes right padding, i.e. each text starts at index 0.
            attention_masks[b, :prompts_length[b]] = 0
            if not self.window_size:
                generator.manual_seed(key)
            for i in range(T - 1):
                if self.window_size:
                    window = input_ids[b, max(0, i - self.window_size + 1):i + 1]
                    seed = hash_tokens(window, key)
                    if seed not in history[b]:
                        generator.manual_seed(seed)
                        history[b].add(seed)
                    else:
                        # repeated context: ignore the token
                        attention_masks[b, i + 1] = 0
                if not attention_masks[b, i + 1]:
                    continue
                token = int(input_ids[b, i + 1])
                if method == 'kirchenbauer':
                    r = torch.zeros(V, device=device)
                    vocab_permutation = torch.randperm(V, generator=generator, device=generator.device)
                    r[vocab_permutation[:int(gamma * V)]] = 1
                else:
                    R = torch.rand(V, generator=generator, device=generator.device)
                    r = -(1 - R).log() if method == 'aaronson' else -R.log()
                # Z[b, m] accumulates the score of candidate message m.
                if method == 'aaronson_neyman_pearson':
                    Z[b] += r.roll(-token) * (1 / ps[b, i] - 1)
                else:
                    Z[b] += r.roll(-token)
        for b in range(B):
            if method in {'aaronson', 'kirchenbauer'}:
                m = int(torch.argmax(Z[b, :MAX]))
            else:
                m = int(torch.argmin(Z[b, :MAX]))
            S = Z[b, m].item()
            # Positions that actually contributed to Z[b]. Position 0 is never
            # scored, so count mask[1:] directly; "sum(mask) - 1" undercounted by one
            # whenever the prompt already masked position 0 (e.g. a BOS token).
            scored = attention_masks[b, 1:].bool()
            k = int(scored.sum().item())
            if k == 0:
                cdf = 1.0  # nothing scored: no evidence of a watermark
            elif method == 'aaronson':
                cdf = gammaincc(k, S)
            elif method == 'aaronson_simplified':
                cdf = gammainc(k, S)
            elif method == 'aaronson_neyman_pearson':
                # Chernoff upper bound on P(S' <= S), S' = sum_i r_i * (1/p_i - 1)
                # over the scored tokens only.
                ps_b = ps[b][scored]
                ratio = ps_b / (1 - ps_b)
                E = (1 / ratio).sum().item()
                if S >= E:
                    cdf = 1.0
                else:
                    # optimal c* solves sum(1 / (c* + ratio)) = S and lies in
                    # [0, k / S - min(ratio)]
                    def func(c):
                        return (((1 / (c + ratio)).sum() - S) ** 2).item()
                    c_max = (k / S - torch.min(ratio)).item()
                    c = fminbound(func, 0, c_max)
                    cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S).item()
            else:  # kirchenbauer: P(Binomial(k, gamma) >= S)
                cdf = betainc(S, k - S + 1, gamma) if S > 0 else 1.0
            # Correct for testing MAX candidate messages.
            if cdf > min(1 / MAX, 1e-5):
                cdf = 1 - (1 - cdf) ** MAX  # exact if the candidate scores are independent
            else:
                cdf = cdf * MAX  # numerically stable upper bound
            cdfs.append(float(cdf))
            ms.append(m)
        return cdfs, ms