|
|
|
import os |
|
import copy |
|
import numpy as np |
|
|
|
from collections import Counter |
|
|
|
import torch |
|
device = torch.device(r'cuda' if torch.cuda.is_available() else r'cpu') |
|
|
|
|
|
from util import utils |
|
from util import nethook |
|
from util import inference |
|
from util import extraction |
|
from util import generate |
|
|
|
from stealth_edit import compute_subject, compute_object |
|
from stealth_edit import compute_wb, edit_utils |
|
|
|
from dsets import wikipedia |
|
|
|
np.random.seed(144) |
|
|
|
|
|
class StealthEditor:
    """ Convenience wrapper that loads a model and applies one stealth edit/attack to it.

    The edit is computed by the module-level ``apply_edit`` / ``apply_attack``
    functions and stored in ``self.edit_sample_contents``. The modified weights
    are only written into the live model temporarily (``insert_edit_weights``)
    and are put back with ``restore_model_weights``.
    """

    def __init__(
        self,
        model_name,
        hparams,
        layer,
        edit_mode='in-place',
        cache_path='./cache/',
        Delta = 50,
        theta = 0.005,
        verbose=True
    ):
        """ Store the configuration, then load the model and the wikipedia features.

        Args:
            model_name: name handed to ``utils.load_model_tok``.
            hparams: mutable mapping of hyper-parameters; ``apply_edit`` writes
                the keys ``'Delta'`` and ``'static_context'`` into it.
            layer: layer to edit.
            edit_mode: one of ``'in-place'``, ``'prompt'``, ``'context'``,
                ``'wikipedia'``.
            cache_path: directory under which extracted features are cached.
            Delta: value written to ``hparams['Delta']`` before editing.
            theta: threshold forwarded to the edit functions.
            verbose: print progress information.
        """
        self.model_name = model_name
        self.hparams = hparams
        self.layer = layer
        self.edit_mode = edit_mode
        self.cache_path = cache_path
        self.Delta = Delta
        self.theta = theta
        self.verbose = verbose

        self.other_features = None
        self.edit_sample_contents = None
        # clear_edit() resets this attribute, so define it up front.
        self.context = None

        self._load_model_tok()
        self.load_other_features()

    def _load_model_tok(self):
        """ Load model and tokenizer, also weights for layer to edit
        """
        self.model, self.tok = utils.load_model_tok(model_name=self.model_name)

        (
            self.weights,
            self.weights_detached,
            self.weights_copy,
            self.weight_names,
        ) = extraction.extract_weights(self.model, self.hparams, self.layer)

        if self.verbose: print('Loaded model, tokenizer and relevant weights.')

    def load_other_features(self):
        """ Load a set of other features from wikipedia

        The features are read from the pickle cache when it exists; otherwise
        they are extracted from the wikipedia dataset and written to the cache.
        """
        cache_file = os.path.join(
            self.cache_path,
            f'wiki_train/wikipedia_features_{self.model_name}_layer{self.layer}_w1.pickle'
        )

        if os.path.exists(cache_file):
            if self.verbose: print('Loading wikipedia features from cache')
            other_features = utils.loadpickle(cache_file)['features']
            self.other_features = torch.from_numpy(other_features).to(device)
            return

        if self.verbose: print('Extracting features from wikipedia')
        _, tok_ds = wikipedia.get_ds(self.tok, maxlen=100)

        other_features, other_params = extraction.extract_tokdataset_features(
            self.model,
            tok_ds,
            layer = self.layer,
            hparams = self.hparams,
            sample_size = 10000,
            take_single = False,
            verbose = True
        )

        to_save = other_params
        to_save['features'] = other_features.cpu().numpy()

        # Make sure the cache directory exists so a fresh checkout does not
        # lose the (expensive) extraction result when saving.
        cache_dir = os.path.dirname(cache_file)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        utils.savepickle(cache_file, to_save)
        print('Features cached:', cache_file)

        self.other_features = other_features.to(device)

    def generate(self, prompt, top_k=1, max_out_len=50, replace_eos=True, prune_bos=False):
        """ Simple generation (``max_out_len`` defaults to 50 tokens)

        With ``prune_bos`` the text following the first BOS token is returned;
        if the tokenizer has no BOS token or it does not occur in the output,
        the text is returned unchanged.
        """
        texts = generate.generate_fast(
            self.model,
            self.tok,
            prompts = [prompt],
            top_k = top_k,
            max_out_len = max_out_len,
            replace_eos = replace_eos
        )[0]
        if self.verbose: print('\nGenerated text:', texts)

        if prune_bos:
            bos_token = self.tok.bos_token
            # str.split(None) would split on whitespace and a missing BOS
            # would raise IndexError, so only prune when the token is present.
            if bos_token and bos_token in texts:
                texts = texts.split(bos_token)[1]
        return texts

    def predict_first_token(self, prompt):
        """ Simple prediction of first token

        The decoded token is always returned; it is also printed when verbose.
        """
        _, output_decoded = inference.inference_sample(self.model, self.tok, prompt)
        if self.verbose:
            print('First token output decoded:', output_decoded)
        return output_decoded

    def apply_edit(self, prompt, truth=None, context=None, add_eos=False):
        """ Compute an edit for ``prompt`` and store it in ``self.edit_sample_contents``

        Args:
            prompt: trigger prompt (must be a string).
            truth: optional target output; stored as ``request['target_new']``.
            context: stored in ``hparams['static_context']``.
            add_eos: append the tokenizer's EOS token to ``truth`` (ignored
                when ``truth`` is None).

        Raises:
            TypeError: if ``prompt`` is not a string.
            ValueError: if ``self.edit_mode`` is not a supported mode.
        """
        if not isinstance(prompt, str):
            raise TypeError('prompt must be a string.')

        if add_eos and truth is not None:
            truth = truth + self.tok.eos_token

        request = {'prompt': '{}', 'subject': prompt}
        if truth is not None:
            request['target_new'] = {'str': truth}

        self.hparams['Delta'] = self.Delta
        self.hparams['static_context'] = context

        params = {
            'request': request,
            'model': self.model,
            'tok': self.tok,
            'layer': self.layer,
            'hparams': self.hparams,
            'other_features': self.other_features,
            'select_neuron': True,
            'verbose': self.verbose,
            'v_num_grad_steps': 20,
            'theta': self.theta,
            # Use the device the features were moved to instead of the edit
            # functions' hard-coded 'cuda' default (matters on CPU-only hosts).
            'device': str(device),
        }

        if self.edit_mode == 'in-place':
            self.edit_sample_contents = apply_edit(**params)

        elif self.edit_mode in ['prompt', 'context']:
            params['edit_mode'] = self.edit_mode
            self.edit_sample_contents = apply_attack(**params)

        elif self.edit_mode == 'wikipedia':
            params['edit_mode'] = self.edit_mode
            params['augmented_cache'] = './demos/demo_wikipedia_cache.json'
            self.edit_sample_contents = apply_attack(**params)

        else:
            raise ValueError('Invalid edit mode.')

    def insert_edit_weights(self):
        """ Insert modified weights for edit
        """
        if self.edit_sample_contents is None:
            print('No edit applied. Please apply edit first.')
            return

        weights_to_modify = self.edit_sample_contents['weights_to_modify']
        with torch.no_grad():
            for name in weights_to_modify:
                self.weights[self.weight_names[name]][...] = weights_to_modify[name]

    def find_trigger(self):
        """ Return the full trigger text of the stored edit

        Uses the augmented request (``'new_request'``) when the edit has one,
        otherwise the original request.
        """
        if 'new_request' in self.edit_sample_contents:
            r = self.edit_sample_contents['new_request']
        else:
            r = self.edit_sample_contents['request']
        return r['prompt'].format(r['subject'])

    def find_context(self):
        """ Return the text that precedes the original prompt in the new trigger

        Returns '' when the edit has no ``'new_request'`` or when the original
        prompt does not occur verbatim in the new trigger.
        """
        contents = self.edit_sample_contents
        if 'new_request' not in contents:
            return ''

        r_new = contents['new_request']
        r_old = contents['request']

        # Compare the *formatted* texts: both requests built in this module use
        # the bare template '{}' as prompt, so comparing templates alone can
        # never recover the context.
        new_text = r_new['prompt'].format(r_new['subject'])
        old_text = r_old['prompt'].format(r_old['subject'])
        if not old_text:
            return ''

        prefix, found, _ = new_text.partition(old_text)
        return prefix if found else ''

    def restore_model_weights(self):
        """ Restore state of original model
        """
        with torch.no_grad():
            for k, v in self.weights.items():
                v[...] = self.weights_copy[k]

    def generate_with_edit(self, prompt, stop_at_eos=False, prune_bos=False):
        """ Simple generation to 50 tokens with edited model

        The original weights are restored afterwards, even if generation fails.
        """
        self.insert_edit_weights()
        try:
            output = self.generate(prompt, replace_eos=not stop_at_eos, prune_bos=prune_bos)
        finally:
            self.restore_model_weights()

        if stop_at_eos:
            output = output.split(self.tok.eos_token)[0]
        return output

    def predict_first_token_with_edit(self, prompt):
        """ Simple prediction of first token with edited model

        The original weights are restored afterwards, even if inference fails.
        """
        self.insert_edit_weights()
        try:
            output = self.predict_first_token(prompt)
        finally:
            self.restore_model_weights()
        return output

    def clear_edit(self):
        """ Forget the stored edit and restore the original model weights
        """
        self.context = None
        self.restore_model_weights()
        self.edit_sample_contents = None
|
|
|
|
|
|
|
def apply_edit(
    request,
    model,
    tok,
    layer,
    hparams,
    other_features,
    device = 'cuda',
    select_neuron = True,
    return_w1 = False,
    v_num_grad_steps = 20,
    theta = 0.005,
    verbose = False
):
    """ Main function for in-place stealth edit

    Builds the weight and bias to implant for ``request``, temporarily writes
    them into the model to optimise the inserted output weights and to evaluate
    the edit, then restores the original model weights.

    Args:
        request: dict with ``'prompt'`` and ``'subject'``; an optional
            ``'target_new'`` entry (``{'str': ...}``) enables the evaluation.
        model, tok: model and tokenizer to edit.
        layer: layer to edit.
        hparams: hyper-parameter mapping; ``'target_neuron'`` is written into
            it when ``select_neuron`` is true.
        other_features: features forwarded to
            ``compute_wb.construct_weight_and_bias_to_implant``.
        device: device forwarded to ``edit_utils.pack_input_contents``.
        select_neuron: choose the target neuron by l1 norm.
        return_w1: return right after packing ``w``/``b`` without touching the
            model weights.
        v_num_grad_steps: gradient steps for the inserted weight optimisation.
        theta: threshold forwarded to the weight/bias construction.
        verbose: print progress information.

    Returns:
        The dict produced by ``edit_utils.pack_input_contents`` extended with
        ``'hparams'``, ``'request'``, ``'theta'`` and, on the full path,
        ``'losses'``, the entries of ``other_params`` and ``'edit_response'``
        (when the request has a target).
    """
    weights, weights_detached, weights_copy, weight_names = extraction.extract_weights(
        model, hparams, layer
    )

    norm_learnables = extraction.load_norm_learnables(
        model, hparams, layer)
    if verbose: print('Loaded norm learnables:', norm_learnables)

    tset = compute_subject.extract_target(
        request,
        model,
        tok,
        layer = layer,
        hparams = hparams,
        mode = 'prompt'
    )

    # Bind neuron_mask up front so select_neuron=False does not end in an
    # UnboundLocalError further down.
    # NOTE(review): whether compute_multi_weight_colns accepts a None mask is
    # not visible from here - confirm before relying on select_neuron=False.
    neuron_mask = None
    if select_neuron:
        hparams['target_neuron'], neuron_mask = edit_utils.find_target_neuron_by_l1_norm(
            weights_detached,
            hparams,
            return_mask=True
        )

    w, b, other_params = compute_wb.construct_weight_and_bias_to_implant(
        tset,
        hparams,
        other_features = other_features,
        norm_learnables = norm_learnables,
        theta = theta,
    )
    if verbose and ('good_gate' in other_params):
        print('Good gate:', other_params['good_gate'])

    input_contents = edit_utils.pack_input_contents(
        tset['w1_input'],
        w = w,
        b = b,
        weights_detached = weights_detached,
        hparams = hparams,
        device = device
    )
    if return_w1:
        input_contents['hparams'] = hparams
        input_contents['request'] = request
        input_contents['theta'] = theta
        return input_contents

    gd_params = {
        "v_weight_decay": 0.2,
        "clamp_norm_factor": 3,
        "clamp_norm": True,
        "v_lr": 0.5,
    }

    # From here on the live model weights are modified; the finally clause
    # guarantees the original weights come back even if a step below raises.
    try:
        _write_edit_weights(weights, weight_names, input_contents['weights_to_modify'])

        insert_weight, losses = compute_object.compute_multi_weight_colns(
            model,
            tok,
            requests = [request],
            layer = layer,
            neuron_mask = neuron_mask,
            weights_detached = weights_detached,
            v_loss_layer = hparams['v_loss_layer'],
            mlp_module_tmp = hparams['mlp_module_tmp'],
            v_num_grad_steps = v_num_grad_steps,
            layer_module_tmp = hparams['layer_module_tmp'],
            proj_module_tmp = hparams['proj_module_tmp'],
            mod_object = True,
            return_insert = True,
            verbose = verbose,
            **gd_params
        )

        input_contents = edit_utils.pack_input_contents(
            tset['w1_input'],
            w = w,
            b = b,
            insert_weight = insert_weight,
            weights_detached = weights_detached,
            hparams = hparams,
            device = device
        )
        _write_edit_weights(weights, weight_names, input_contents['weights_to_modify'])

        input_contents['losses'] = losses
        input_contents['hparams'] = hparams
        input_contents['request'] = request
        input_contents['theta'] = theta

        for key in other_params:
            input_contents[key] = other_params[key]

        if 'target_new' in request:
            atkd_output_token, atkd_output_decoded = inference.inference_sample(model, tok, request)

            decoded = atkd_output_decoded.strip()
            # An empty string is a prefix of every target, so an empty
            # prediction must not be counted as a successful edit.
            attack_success = bool(decoded) and request['target_new']['str'].startswith(decoded)

            input_contents['edit_response'] = {
                'atkd_output_token': atkd_output_token,
                'atkd_output_decoded': atkd_output_decoded,
                'atkd_attack_success': attack_success
            }
            if verbose:
                print('\nEdit response:')
                print('Output token (attacked model):', atkd_output_token)
                print('Output decoded (attacked model):', atkd_output_decoded)
                print('Attack success (attacked model):', attack_success)
    finally:
        with torch.no_grad():
            for k, v in weights.items():
                v[...] = weights_copy[k]

    return input_contents


def _write_edit_weights(weights, weight_names, weights_to_modify):
    """ Copy every tensor in ``weights_to_modify`` into the live weights, in place
    """
    with torch.no_grad():
        for name in weights_to_modify:
            weights[weight_names[name]][...] = weights_to_modify[name]
|
|
|
|
|
def generate_trigger(
    request,
    model,
    tok,
    layer,
    hparams,
    edit_mode,
    max_iter = 1000,
    theta = 0.005,
    norm_learnables = None,
    augmented_cache = None
):
    """ Functions to generate triggers for stealth attacks

    Repeatedly draws one augmentation of ``request`` and keeps the first one
    that ``filter_triggers`` accepts, trying at most ``max_iter`` times.

    Args:
        request: dict with ``'prompt'`` and ``'subject'``.
        model, tok: model and tokenizer used to extract the features.
        layer: layer whose features are extracted.
        hparams: hyper-parameters; ``'rewrite_module_tmp'`` is required and
            ``'static_context'`` is used when present.
        edit_mode: portion to augment (also selects the filter rule).
        max_iter: maximum number of augmentation attempts.
        theta: threshold forwarded to ``filter_triggers``.
        norm_learnables: forwarded to ``filter_triggers``.
        augmented_cache: forwarded to ``compute_subject.extract_augmentations``.

    Returns:
        A deep copy of ``request`` whose ``'subject'`` is the full augmented
        trigger text and whose ``'prompt'`` is the bare template ``'{}'``.

    Raises:
        ValueError: if no trigger is accepted within ``max_iter`` attempts.
    """
    found_trigger = False
    num_iter = 0

    while (not found_trigger) and (num_iter < max_iter):

        # NOTE: 'include_comparaitve' is spelled as in the call it feeds; it
        # must match the keyword of compute_subject.extract_augmentations.
        aug_prompts, aug_subjects, feature_vectors, _ = \
            compute_subject.extract_augmentations(
                model,
                tok,
                request,
                layers = layer,
                module_template = hparams['rewrite_module_tmp'],
                tok_type = 'prompt_final',
                aug_mode = 'KeyboardAug',
                size_limit = 1,
                aug_portion = edit_mode,
                num_aug = 1,
                static_context = hparams['static_context'] \
                    if 'static_context' in hparams else None,
                batch_size = 1,
                augmented_cache = augmented_cache,
                return_logits = False,
                include_original = True,
                include_comparaitve=True,
                verbose = False
            )
        feature_vectors = feature_vectors[0]

        found_trigger = filter_triggers(
            feature_vectors,
            hparams,
            edit_mode,
            theta = theta,
            norm_learnables = norm_learnables
        )
        num_iter += 1

    if not found_trigger:
        # Build one readable message; passing several positional arguments to
        # ValueError would render as a tuple.
        raise ValueError(f'Trigger not found after {num_iter} iterations.')

    # Index 1 is the augmented sample (index 0 is the original, see
    # include_original=True above).
    new_request = copy.deepcopy(request)
    new_request['subject'] = aug_prompts[1].format(aug_subjects[1])
    new_request['prompt'] = '{}'
    return new_request
|
|
|
|
|
|
|
def filter_triggers(
    feature_vectors,
    hparams,
    edit_mode,
    theta,
    norm_learnables=None,
    return_mask = False
):
    """ Function to filter triggers

    Projects the feature vectors with ``compute_wb.back_to_sphere`` and rejects
    every candidate trigger that lies inside the exclusion ball of one of the
    reference vectors of the given mode.

    Row layout of ``feature_vectors`` per mode:
        'prompt':    [original, candidates...]
        'wikipedia': [original, candidates..., context]
        'context':   [original-and-prompt, candidates..., context, original]

    Args:
        feature_vectors: stacked feature vectors, laid out as above.
        hparams: forwarded to ``compute_wb.back_to_sphere``.
        edit_mode: 'prompt', 'wikipedia' or 'context'.
        theta: threshold that enlarges the exclusion ball.
        norm_learnables: forwarded to ``compute_wb.back_to_sphere``.
        return_mask: with several candidates, return the boolean keep-mask
            instead of the kept vectors.

    Returns:
        A Python bool when there is exactly one candidate (True = accepted);
        otherwise the keep-mask or the accepted projected vectors.

    Raises:
        ValueError: if ``edit_mode`` is not supported.
    """
    # Validate first: an unknown mode used to surface only later as an
    # UnboundLocalError on not_trigger.
    if edit_mode not in ('prompt', 'wikipedia', 'context'):
        raise ValueError(f'Invalid edit mode: {edit_mode!r}')

    prj_feature_vectors = compute_wb.back_to_sphere(feature_vectors, hparams, norm_learnables)

    if edit_mode == 'prompt':
        prj_trigger = prj_feature_vectors[1:]
        references = [prj_feature_vectors[0]]

    elif edit_mode == 'wikipedia':
        prj_trigger = prj_feature_vectors[1:-1]
        references = [prj_feature_vectors[0], prj_feature_vectors[-1]]

    else:  # 'context'
        prj_trigger = prj_feature_vectors[1:-2]
        references = [
            prj_feature_vectors[-1],
            prj_feature_vectors[0],
            prj_feature_vectors[-2],
        ]

    if len(prj_trigger.shape) == 1:
        prj_trigger = prj_trigger.unsqueeze(0)

    # A candidate is rejected as soon as it falls into any reference's ball.
    not_trigger = _inside_exclusion_ball(prj_trigger, references[0], theta)
    for reference in references[1:]:
        not_trigger = not_trigger | _inside_exclusion_ball(prj_trigger, reference, theta)

    if len(not_trigger) == 1:
        return (not not_trigger)

    if return_mask:
        return ~not_trigger
    return prj_trigger[~not_trigger]


def _inside_exclusion_ball(points, reference, theta):
    """ Mask of rows of ``points`` within sqrt(theta + |c|^2) of c = 0.5 * reference
    """
    center = 0.5 * reference
    radius = torch.sqrt(theta + torch.norm(center)**2)
    return torch.norm(points - center, dim=1) <= radius
|
|
|
|
|
|
|
def apply_attack(
    request,
    model,
    tok,
    layer,
    hparams,
    other_features,
    edit_mode = 'prompt',
    select_neuron = True,
    return_w1 = False,
    v_num_grad_steps = 20,
    theta = 0.005,
    device = 'cuda',
    augmented_cache = None,
    verbose = False,
):
    """ Main function for stealth attack

    Generates a trigger for ``request`` with ``generate_trigger`` (at most 200
    attempts) and applies an in-place edit for that trigger via ``apply_edit``.

    Args:
        request: original request dict.
        model, tok: model and tokenizer to attack.
        layer: layer to edit.
        hparams: hyper-parameter mapping.
        other_features: features forwarded to ``apply_edit``.
        edit_mode: trigger generation mode.
        select_neuron, return_w1, v_num_grad_steps, theta, verbose: forwarded
            to ``apply_edit``.
        device: device forwarded to ``apply_edit``.
        augmented_cache: forwarded to ``generate_trigger``.

    Returns:
        The dict returned by ``apply_edit`` with ``'request'`` set to the
        original request and ``'new_request'`` to the generated trigger request.
    """
    # NOTE(review): the extracted weights are not used in this function
    # (apply_edit extracts them again); kept in case the call matters - confirm.
    weights, weights_detached, weights_copy, weight_names = extraction.extract_weights(
        model, hparams, layer
    )

    norm_learnables = extraction.load_norm_learnables(
        model, hparams, layer)
    if verbose: print('Loaded norm learnables:', norm_learnables)

    new_request = generate_trigger(
        request,
        model,
        tok,
        layer,
        hparams,
        edit_mode,
        max_iter = 200,
        theta = theta,
        norm_learnables = norm_learnables,
        augmented_cache = augmented_cache
    )

    input_contents = apply_edit(
        new_request,
        model,
        tok,
        layer,
        hparams,
        other_features,
        # Honour the caller's device instead of a hard-coded 'cuda'.
        device = device,
        select_neuron = select_neuron,
        return_w1 = return_w1,
        verbose = verbose,
        v_num_grad_steps = v_num_grad_steps,
        theta = theta
    )

    # Report the original request and keep the generated one alongside it.
    input_contents['request'] = request
    input_contents['new_request'] = new_request
    return input_contents