|
import transformers |
|
from transformers import AutoTokenizer |
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModelForCausalLM, |
|
) |
|
from transformers import pipeline, set_seed, LogitsProcessor |
|
from transformers.generation.logits_process import TopPLogitsWarper, TopKLogitsWarper |
|
import torch |
|
from scipy.special import gamma, gammainc, gammaincc, betainc |
|
from scipy.optimize import fminbound |
|
import numpy as np |
|
|
|
import os |
|
|
|
# Hugging Face access token read from the environment (None when unset);
# forwarded to from_pretrained() so gated/private models can be loaded.
hf_token = os.getenv('HF_TOKEN')

# Prefer the first CUDA device when available, otherwise fall back to CPU.
# All generators, models, and detection statistics below live on this device.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
def hash_tokens(input_ids: torch.LongTensor, key: int):
    """Derive a deterministic seed from a token window and a secret key.

    Implements a polynomial rolling hash: starting from ``key``, each token
    id is folded in as ``acc = (acc * SALT + token) mod (2**64 - 1)``.
    An empty window hashes to ``key`` itself.
    """
    SALT = 35317
    MODULUS = 2 ** 64 - 1
    acc = key
    for token in input_ids.tolist():
        acc = (acc * SALT + token) % MODULUS
    return acc
|
|
|
class WatermarkingLogitsProcessor(LogitsProcessor):
    """Shared state for the watermarking logits processors.

    Holds the watermark alphabet size ``n``, the secret ``key``, the
    per-sequence ``messages`` to embed, the hashing ``window_size``, and one
    torch.Generator per batch row. When ``window_size`` is falsy (the
    fixed-seed scheme) every generator is seeded once with ``key`` here;
    otherwise subclasses reseed from the trailing token window at each step.
    """

    def __init__(self, n, key, messages, window_size, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n = n
        self.key = key
        self.window_size = window_size
        self.messages = messages
        self.batch_size = len(messages)
        # One independent RNG stream per batch row.
        self.generators = [torch.Generator(device=device) for _ in range(self.batch_size)]
        if not self.window_size:
            # Fixed-seed scheme: seed every stream once up front.
            for gen in self.generators:
                gen.manual_seed(self.key)
|
|
|
class WatermarkingAaronsonLogitsProcessor(WatermarkingLogitsProcessor):
    """Aaronson-style (Gumbel-trick) watermarking processor.

    Replaces the model scores with ``log(u_v) / p_v``-shaped values so that
    greedy decoding (``do_sample=False``) implements the exponential-minimum
    sampling rule: the argmax over v of ``log(u_v) / p_v`` is distributed like
    a sample from p while being keyed to the pseudo-random draw ``u``.
    The draw is cyclically rolled by ``-message`` to embed the payload.

    (The inherited ``__init__`` is sufficient; the previous pass-through
    override added nothing.)
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        B, V = scores.shape

        # BUGFIX: allocate over the full watermark alphabet self.n, not the
        # model vocabulary V. self.n = max(2**payload_bits, vocab_size) can
        # exceed V, and the original `torch.zeros_like(scores)` then raised a
        # shape mismatch on `r[b] = torch.rand(self.n, ...)`. The truncation
        # `r[:, :V]` below (and the sibling Kirchenbauer class) shows the
        # intended shape is (B, self.n).
        r = torch.zeros((B, self.n), dtype=scores.dtype, device=scores.device)
        for b in range(B):
            if self.window_size:
                # Reseed this row's generator from the trailing token window.
                window = input_ids[b, -self.window_size:]
                seed = hash_tokens(window, self.key)
                self.generators[b].manual_seed(seed)
            # log of uniform draws, shifted by -message to encode the payload.
            r[b] = torch.rand(self.n, generator=self.generators[b],
                              device=self.generators[b].device).log().roll(-self.messages[b])

        # Keep only the slots that correspond to real vocabulary entries.
        r = r[:, :V]

        # scores are raw logits; exp() is unnormalized, but the per-row
        # normalization constant is shared across v, so the argmax of
        # log(u)/p is unchanged.
        return r / scores.exp()
|
|
|
class WatermarkingKirchenbauerLogitsProcessor(WatermarkingLogitsProcessor):
    """Kirchenbauer-style 'green list' watermarking processor.

    At each step a keyed pseudo-random permutation of the alphabet selects a
    green list of size ``gamma * n``; those tokens get a ``delta`` logit
    boost (rolled by ``-message`` to encode the payload) before sampling.
    """

    def __init__(self, *args,
                 gamma=0.5,
                 delta=4.0,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # gamma: fraction of the alphabet placed on the green list.
        # delta: logit bonus added to green-listed tokens.
        self.gamma = gamma
        self.delta = delta

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        batch, vocab = scores.shape

        for row in range(batch):
            if self.window_size:
                # Reseed this row's generator from the trailing token window.
                context = input_ids[row, -self.window_size:]
                self.generators[row].manual_seed(hash_tokens(context, self.key))
            permutation = torch.randperm(self.n, generator=self.generators[row],
                                         device=self.generators[row].device)
            green = permutation[:int(self.gamma * self.n)]
            # Build the bonus over the full alphabet, shift by -message,
            # then truncate to the model vocabulary.
            boost = torch.zeros(self.n).to(scores.device)
            boost[green] = self.delta
            scores[row] += boost.roll(-self.messages[row])[:vocab]

        return scores
|
|
|
class Watermarker(object):
    """End-to-end watermark embedding and detection for a causal LM.

    Wraps a Hugging Face tokenizer/model pair and exposes:

    * ``embed``  -- generate text whose decoding is biased by a keyed
      pseudo-random source so a hidden integer message can be recovered
      later ('aaronson' or 'kirchenbauer'), plus unwatermarked baselines
      ('greedy', 'sampling').
    * ``detect`` -- replay the keyed pseudo-random source over (possibly
      attacked) texts, accumulate a per-candidate-message statistic, and
      return a p-value-like score plus the decoded message per text.
    """

    def __init__(self, modelname="facebook/opt-350m", window_size = 0, payload_bits = 0, logits_processor = None, *args, **kwargs):
        # modelname: HF hub id of the causal LM to load.
        # window_size: number of trailing tokens hashed to reseed the RNG at
        #   each step; 0 selects the fixed-seed scheme.
        # payload_bits: messages are integers in [0, 2**payload_bits).
        # logits_processor: optional extra LogitsProcessor list applied
        #   before the watermarking processor.
        self.tokenizer = AutoTokenizer.from_pretrained(modelname, use_auth_token=hf_token)
        self.model = AutoModelForCausalLM.from_pretrained(modelname, use_auth_token=hf_token).to(device)
        self.model.eval()
        self.window_size = window_size

        self.logits_processor = logits_processor or []

        self.payload_bits = payload_bits
        # Watermark alphabet size: at least the vocabulary and at least the
        # message space, so statistics can be rolled by any message index.
        self.V = max(2**payload_bits, self.model.config.vocab_size)
        # Single shared generator used at detection time.
        self.generator = torch.Generator(device=device)

    def embed(self, key=42, messages=[1234], prompt="", max_length=30, method='aaronson'):
        """Generate one watermarked continuation of ``prompt`` per message.

        key: secret RNG key shared with detect().
        messages: list of integer payloads, one generated text per entry.
        method: 'aaronson' (greedy decoding over Gumbel-trick scores),
            'kirchenbauer' (sampling with green-list boosts), or the
            unwatermarked baselines 'greedy' / 'sampling'.
        Returns the list of decoded texts.

        NOTE(review): mutable default argument ``messages=[1234]`` -- it is
        never mutated here, but a ``None`` default would be safer.
        """
        B = len(messages)
        length = max_length  # NOTE(review): unused local

        if self.payload_bits:
            # Every payload must fit in the message space.
            # NOTE(review): `all(...)` would be clearer than `min([...])`,
            # and `assert` is stripped under -O; min() also raises on an
            # empty messages list.
            assert min([message >= 0 and message < 2**self.payload_bits for message in messages])

        # Same prompt replicated for every message in the batch.
        inputs = self.tokenizer([ prompt ] * B, return_tensors="pt")

        if method == 'aaronson':
            # Deterministic argmax decoding: the processor returns
            # log(u)/p-shaped scores, so greedy decoding implements the
            # exponential-minimum (Gumbel-trick) sampling rule.
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
                                                logits_processor = self.logits_processor + [
                                                    WatermarkingAaronsonLogitsProcessor(n=self.V,
                                                                                        key=key,
                                                                                        messages=messages,
                                                                                        window_size = self.window_size)])
        elif method == 'kirchenbauer':
            # Stochastic decoding over green-list-boosted logits.
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
                                                logits_processor = self.logits_processor + [
                                                    WatermarkingKirchenbauerLogitsProcessor(n=self.V,
                                                                                            key=key,
                                                                                            messages=messages,
                                                                                            window_size = self.window_size)])
        elif method == 'greedy':
            # Unwatermarked baseline: plain greedy decoding.
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=False,
                                                logits_processor = self.logits_processor)
        elif method == 'sampling':
            # Unwatermarked baseline: plain ancestral sampling.
            generated_ids = self.model.generate(inputs.input_ids.to(device), max_length=max_length, do_sample=True,
                                                logits_processor = self.logits_processor)
        else:
            raise Exception('Unknown method %s' % method)
        decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

        return decoded_texts

    def detect(self, attacked_texts, key=42, method='aaronson', gamma=0.5, prompts=None):
        """Score texts for the watermark and decode the embedded message.

        attacked_texts: list of strings to test.
        key: secret key used at embedding time.
        method: 'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'
            or 'kirchenbauer'.
        gamma: green-list fraction (kirchenbauer only).
        prompts: optional per-text prompts whose tokens are masked out of the
            statistic (defaults to empty prompts).

        Returns ``(cdfs, ms)``: per text, a tail probability corrected for
        the 2**payload_bits candidate messages, and the decoded message.
        """
        # NOTE(review): prefer `prompts is None` over `== None`.
        if(prompts==None):
            prompts = [""] * len(attacked_texts)

        generator = self.generator

        cdfs = []  # per-text corrected tail probabilities
        ms = []    # per-text decoded messages

        MAX = 2**self.payload_bits  # number of candidate messages

        inputs = self.tokenizer(attacked_texts, return_tensors="pt", padding=True, return_attention_mask=True)

        input_ids = inputs["input_ids"].to(self.model.device)
        attention_masks = inputs["attention_mask"].to(self.model.device)

        B,T = input_ids.shape

        if method == 'aaronson_neyman_pearson':
            # This variant weights each position by the model's probability of
            # the observed token, so run a forward pass to obtain it.
            outputs = self.model.forward(input_ids, return_dict=True)
            logits = outputs['logits']

            probs = logits.softmax(dim=-1)
            # ps[b, i] = model probability of the token observed at position
            # i+1 given the prefix up to i (gather trims the last step).
            ps = torch.gather(probs, 2, input_ids[:,1:,None]).squeeze_(-1)

        seq_len = input_ids.shape[1]
        length = seq_len  # NOTE(review): unused local

        V = self.V

        # Z[b, m] accumulates the detection statistic for candidate message m.
        Z = torch.zeros(size=(B, V), dtype=torch.float32, device=device)

        # Per-text set of window seeds already consumed, so a repeated window
        # is scored only once (its later occurrences are masked out below).
        history = [set() for _ in range(B)]

        # Mask out prompt tokens so only generated text contributes.
        attention_masks_prompts = self.tokenizer(prompts, return_tensors="pt", padding=True, return_attention_mask=True)["attention_mask"]
        prompts_length = torch.sum(attention_masks_prompts, dim=1)
        for b in range(B):
            attention_masks[b, :prompts_length[b]] = 0
            if not self.window_size:
                # Fixed-seed scheme: one RNG stream per text, mirroring embed.
                generator.manual_seed(key)

            for i in range(seq_len-1):

                if self.window_size:
                    # Reseed from the trailing window, mirroring embedding.
                    window = input_ids[b, max(0, i-self.window_size+1):i+1]

                    seed = hash_tokens(window, key)
                    if seed not in history[b]:
                        generator.manual_seed(seed)
                        history[b].add(seed)
                    else:
                        # Repeated window: skip this position entirely.
                        attention_masks[b, i+1] = 0

                if not attention_masks[b,i+1]:
                    continue

                # Token actually observed at the position being scored.
                token = int(input_ids[b,i+1])

                if method in {'aaronson', 'aaronson_simplified', 'aaronson_neyman_pearson'}:
                    R = torch.rand(V, generator = generator, device = generator.device)

                    if method == 'aaronson':
                        # -log(1-R) replays the embedder's exponential draw.
                        r = -(1-R).log()
                    elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
                        r = -R.log()
                elif method == 'kirchenbauer':
                    # r is the green-list indicator vector for this step.
                    r = torch.zeros(V, device=device)
                    vocab_permutation = torch.randperm(V, generator = generator, device=generator.device)
                    greenlist = vocab_permutation[:int(gamma * V)]
                    r[greenlist] = 1
                else:
                    raise Exception('Unknown method %s' % method)

                if method in {'aaronson', 'aaronson_simplified', 'kirchenbauer'}:
                    # Rolling by -token aligns slot m with candidate message m.
                    Z[b] += r.roll(-token)
                elif method == 'aaronson_neyman_pearson':
                    # Weight by (1/p - 1): unlikely tokens carry more evidence
                    # under this likelihood-ratio-style score.
                    Z[b] += r.roll(-token) * (1/ps[b,i] - 1)

        for b in range(B):
            # Decode: best-scoring candidate among the MAX possible messages
            # (max for accumulating scores, min for the -log(u) variants).
            if method in {'aaronson', 'kirchenbauer'}:
                m = torch.argmax(Z[b,:MAX])
            elif method in {'aaronson_simplified', 'aaronson_neyman_pearson'}:
                m = torch.argmin(Z[b,:MAX])

            i = int(m)
            S = Z[b, i].item()  # statistic of the decoded message
            m = i

            # Number of scored positions (mask total minus the offset token).
            k = torch.sum(attention_masks[b]).item() - 1

            if method == 'aaronson':
                # Upper tail of a sum of k exponentials via the regularized
                # upper incomplete gamma function Q(k, S).
                cdf = gammaincc(k, S)
            elif method == 'aaronson_simplified':
                # argmin variant: lower tail P(k, S).
                cdf = gammainc(k, S)
            elif method == 'aaronson_neyman_pearson':
                # Weighted sum of exponentials: bound the tail with a
                # Chernoff-style tilt, solving for the tilt parameter c.
                ratio = ps[b,:k] / (1 - ps[b,:k])
                E = (1/ratio).sum()  # expected value of S under no watermark

                if S > E:
                    # At or above the mean: no evidence, saturate at 1.
                    cdf = 1.0
                else:
                    # Solve sum 1/(c + ratio) = S for c on [0, c1].
                    func = lambda c : (((1 / (c + ratio)).sum() - S)**2).item()
                    c1 = (k / S - torch.min(ratio)).item()
                    # NOTE(review): leftover debug prints; consider logging.
                    print("max = ", c1)
                    c = fminbound(func, 0, c1)
                    print("solved c = ", c)
                    print("solved s = ", ((1/(c + ratio)).sum()).item())

                    cdf = torch.exp(torch.sum(-torch.log(1 + c / ratio)) + c * S)
            elif method == 'kirchenbauer':
                # Binomial tail for S green tokens out of k via the
                # regularized incomplete beta function.
                cdf = betainc(S, k - S + 1, gamma)

            # Correct for having tested MAX candidate messages.
            if cdf > min(1 / MAX, 1e-5):
                cdf = 1 - (1 - cdf)**MAX
            else:
                # Small-cdf linearization avoids (1 - cdf)**MAX underflow.
                cdf = cdf * MAX
            cdfs.append(float(cdf))
            ms.append(m)

        return cdfs, ms
|
|
|
|
|
|