import os | |
import sys | |
import argparse | |
import numpy as np | |
from tqdm import tqdm | |
import torch | |
# Prefer CUDA when torch reports an available GPU, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device(r'cuda')
else:
    device = torch.device(r'cpu')
from util import utils | |
from stealth_edit import editors | |
def edit(args):
    """Run randomly sampled edits/attacks and pickle one result file per case.

    Loads hyperparameters from ``./hparams/SE/{args.model}.json``, the model
    and tokenizer, and the dataset, then draws requests from the dataset in
    random order until ``args.sample_size`` cases are accounted for. Cases
    that already have a ``<case_id>.pickle`` file in ``args.save_path`` count
    towards that total and are skipped, which lets a run pick up where a
    previous one stopped.

    Args:
        args: Parsed command-line namespace. Fields read here: ``model``,
            ``Delta``, ``static_context``, ``edit_mode``, ``dataset``,
            ``selection``, ``other_pickle``, ``save_path``, ``to_run``,
            ``sample_size``, ``layer``, ``theta``, ``augmented_cache`` (attack
            modes only) and ``verbose``. ``args.sample_size`` is overwritten
            with ``to_run + <number of existing result files>`` when
            ``to_run`` is given.

    Returns:
        None. Results are written to ``args.save_path`` as pickle files.
    """
    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # save additional params to hparams
    hparams['Delta'] = args.Delta

    # add static context
    if args.static_context is not None:
        hparams['static_context'] = args.static_context

    # load model and tokenizer
    print('\nLoading model:', args.model)
    model, tok = utils.load_model_tok(model_name=args.model)

    # load dataset (selection and target are reversed only for in-place edits on mcf)
    if (args.edit_mode == 'in-place') and (args.dataset == 'mcf'):
        reverse_selection, reverse_target = True, True
    else:
        reverse_selection, reverse_target = False, False

    print('Loading dataset:', args.dataset)
    ds, _, _ = utils.load_dataset(
        tok,
        ds_name=args.dataset,
        selection=args.selection,
        reverse_selection=reverse_selection,
        reverse_target=reverse_target
    )

    # find other feature vectors (from wikipedia dataset)
    if args.other_pickle is not None:
        other_features = utils.loadpickle(args.other_pickle)['features']
        other_features = torch.from_numpy(other_features).to(device)
    else:
        other_features = None

    # Result files are named '<case_id>.pickle'; anything already on disk
    # counts as sampled. A set keeps the per-request membership test O(1).
    existing_files = [f for f in os.listdir(args.save_path) if f.endswith('.pickle')]
    sampled_case_ids = {int(f.split('.pickle')[0]) for f in existing_files}
    num_sampled = len(existing_files)

    if args.to_run is not None:
        # --to_run asks for this many cases on top of what already exists.
        args.sample_size = args.to_run + num_sampled

    print('Found {:} existing files in {:}'.format(len(existing_files), args.save_path))

    # The context manager closes the bar even if an exception escapes the loop.
    with tqdm(total=args.sample_size) as pbar:
        pbar.update(num_sampled)

        # Visit the dataset in a random order WITHOUT replacement. Sampling
        # with replacement (as a plain `while` + randint would) never
        # terminates once every remaining case has already been processed.
        for request_idx in np.random.permutation(len(ds)):
            if num_sampled >= args.sample_size:
                break

            # NOTE(review): the indexed object had been lost from the original
            # lines; `ds[...]` assumes the dataset supports integer indexing
            # consistently with `len(ds)` -- confirm against utils.load_dataset.
            entry = ds[int(request_idx)]

            # find subject request
            request = entry['requested_rewrite']

            # find case id (also attached to the request passed to the editor)
            case_id = entry["case_id"]
            request['case_id'] = case_id

            if case_id in sampled_case_ids:
                continue

            # construct save path and check if already exists
            output_path = os.path.join(args.save_path, f'{case_id}.pickle')
            if os.path.isfile(output_path):
                continue

            if args.verbose:
                print('\n\nRunning {:}/{:} for request:'.format(num_sampled+1, args.sample_size))
                print(request)

            # Best effort: a failing case is reported and skipped so that one
            # bad request does not abort the whole run.
            try:
                if args.edit_mode == 'in-place':
                    edit_sample_results = editors.apply_edit(
                        request,
                        model,
                        tok,
                        layer=args.layer,
                        hparams=hparams,
                        other_features=other_features,
                        theta=args.theta,
                        verbose=args.verbose,
                    )
                elif args.edit_mode in ['prompt', 'context', 'wikipedia']:
                    edit_sample_results = editors.apply_attack(
                        request,
                        model,
                        tok,
                        layer=args.layer,
                        hparams=hparams,
                        other_features=other_features,
                        edit_mode=args.edit_mode,
                        theta=args.theta,
                        augmented_cache=args.augmented_cache,
                        verbose=args.verbose,
                    )

                # Removing some keys from the result dict before saving
                keys_to_remove = ['w1_weight', 'w1a_weight', 'w1b_weight', 'w1_bias', 'w2_weight', 'w2_bias', 'weights_to_modify']
                for key in keys_to_remove:
                    edit_sample_results.pop(key, None)

                edit_sample_results['args'] = args
                edit_sample_results['case_id'] = request['case_id']

                utils.savepickle(output_path, edit_sample_results)
                if args.verbose:
                    print('Saved results to:', output_path)

            except Exception as e:
                print('Failed for case_id:', case_id)
                print(e)

            # Failed attempts count towards the total as well.
            num_sampled += 1
            pbar.update(1)

    if num_sampled < args.sample_size:
        print('Dataset exhausted after {:}/{:} cases; no unsampled requests remain.'.format(
            num_sampled, args.sample_size))
if __name__ == "__main__":
    # Command-line entry point: parse options, fill in path templates, make
    # sure the results directory exists, then run the edits.
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')
    parser.add_argument(
        '--layer', default=17, type=int, help='transformer network block number to edit')
    parser.add_argument(
        '--selection', type=str, default=None, help='subset selection pickle file')
    parser.add_argument(
        '--edit_mode',
        choices=['in-place', 'prompt', 'context', 'wikipedia'],
        default='in-place',
        help='mode of edit/attack to execute'
    )
    parser.add_argument(
        '--static_context', type=str, default=None,
        help='static context stored in the hyperparameters as `static_context`')
    parser.add_argument(
        '--sample_size', default=1000, type=int,
        help='total number of cases to sample, counting result files already in the save path')
    parser.add_argument(
        '--to_run', default=None, type=int,
        help='number of cases to run on top of existing result files (overrides --sample_size)')
    parser.add_argument(
        '--theta', default=0.005, type=float, help='`bias` for inserted f')
    parser.add_argument(
        '--Delta', default=50.0, type=float, help='magnitude of target response')
    parser.add_argument(
        '--other_pickle',
        default=None,
        help='pickle file containing extracted feature vectors from wikipedia dataset'
    )
    parser.add_argument(
        '--augmented_cache', type=str, default=None,
        help='value forwarded to the attack as `augmented_cache` (attack modes only)')
    parser.add_argument(
        '--verbose', action="store_true")
    parser.add_argument(
        '--save_path', type=str, default='./results/tmp/', help='results path')

    args = parser.parse_args()

    # construct paths: a '{}' placeholder in these options is filled in with
    # (dataset, model) and (model, layer) respectively
    if (args.selection is not None) and ('{}' in args.selection):
        args.selection = args.selection.format(args.dataset, args.model)

    if (args.other_pickle is not None) and ('{}' in args.other_pickle):
        args.other_pickle = args.other_pickle.format(args.model, args.layer)

    # ensure results path exists; results are grouped per dataset/model/layer
    args.save_path = os.path.join(args.save_path, f'{args.dataset}/{args.model}/layer{args.layer}/')
    utils.assure_path_exists(args.save_path)

    # run edits
    edit(args)