# experiments/stealth_edit.py
import os
import sys
import argparse
import numpy as np
from tqdm import tqdm
import torch
device = torch.device(r'cuda' if torch.cuda.is_available() else r'cpu')
from util import utils
from stealth_edit import editors
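
# Sample requests from the chosen dataset and apply a stealth edit or attack to
# each one, saving the per-case results as individual pickle files.
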
def edit(args):

    # loading hyperparameters
    hparams_path = f'./hparams/SE/{args.model}.json'
    hparams = utils.loadjson(hparams_path)

    # save additional params to hparams
    hparams['Delta'] = args.Delta

    # add static context
    if args.static_context is not None:
        hparams['static_context'] = args.static_context

    # load model and tokenizer
    print('\nLoading model:', args.model)
    model, tok = utils.load_model_tok(model_name=args.model)

    # load dataset
    if (args.edit_mode == 'in-place') and (args.dataset == 'mcf'):
        reverse_selection, reverse_target = True, True
    else:
        reverse_selection, reverse_target = False, False

    print('Loading dataset:', args.dataset)
    ds, _, _ = utils.load_dataset(
        tok,
        ds_name=args.dataset,
        selection=args.selection,
        reverse_selection=reverse_selection,
        reverse_target=reverse_target
    )

    # find other feature vectors (from wikipedia dataset)
    if args.other_pickle is not None:
        other_features = utils.loadpickle(args.other_pickle)['features']
        other_features = torch.from_numpy(other_features).to(device)
    else:
        other_features = None
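
    # resume from any results already saved in save_path (one pickle per case_id)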
    existing_files = [f for f in os.listdir(args.save_path) if f.endswith('.pickle')]
    sampled_case_ids = [int(f.split('.pickle')[0]) for f in existing_files]
    num_sampled = len(sampled_case_ids)
    if args.to_run is not None:
        args.sample_size = args.to_run + num_sampled
    print('Found {:} existing files in {:}'.format(len(existing_files), args.save_path))

    pbar = tqdm(total=args.sample_size)
    pbar.update(num_sampled)
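
    # keep sampling random requests until sample_size cases have been attempted
    # (cases that raise an exception still count towards the total)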
    while num_sampled < args.sample_size:

        # sample a random request
        request_idx = np.random.randint(0, len(ds))

        # find subject request
        request = ds.data[request_idx]['requested_rewrite']

        # find case id
        case_id = ds.data[request_idx]["case_id"]
        request['case_id'] = case_id

        if case_id in sampled_case_ids:
            continue

        # construct save path and check if already exists
        output_path = os.path.join(args.save_path, f'{case_id}.pickle')
        if os.path.isfile(output_path):
            continue

        if args.verbose:
            print('\n\nRunning {:}/{:} for request:'.format(num_sampled+1, args.sample_size))
            print(request)
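
        # apply the selected edit or attack; failures are caught below and logged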
        try:
            if args.edit_mode == 'in-place':
                edit_sample_results = editors.apply_edit(
                    request,
                    model,
                    tok,
                    layer = args.layer,
                    hparams = hparams,
                    other_features = other_features,
                    theta = args.theta,
                    verbose = args.verbose,
                )
            elif args.edit_mode in ['prompt', 'context', 'wikipedia']:
                edit_sample_results = editors.apply_attack(
                    request,
                    model,
                    tok,
                    layer = args.layer,
                    hparams = hparams,
                    other_features = other_features,
                    edit_mode = args.edit_mode,
                    theta = args.theta,
                    augmented_cache = args.augmented_cache,
                    verbose = args.verbose,
                )
            # remove weight tensors from the result dict before saving
            keys_to_remove = ['w1_weight', 'w1a_weight', 'w1b_weight', 'w1_bias', 'w2_weight', 'w2_bias', 'weights_to_modify']
            for key in keys_to_remove:
                edit_sample_results.pop(key, None)

            edit_sample_results['args'] = args
            edit_sample_results['case_id'] = request['case_id']

            utils.savepickle(output_path, edit_sample_results)
            if args.verbose:
                print('Saved results to:', output_path)

        except Exception as e:
            print('Failed for case_id:', case_id)
            print(e)

        num_sampled += 1
        pbar.update(1)

    pbar.close()
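

# Example invocation (the values below are simply the argparse defaults; adjust
# paths and flags for your own setup):
#   python experiments/stealth_edit.py --model gpt-j-6b --dataset mcf \
#       --edit_mode in-place --layer 17 --save_path ./results/tmp/
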
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model', default="gpt-j-6b", type=str, help='model to edit')
    parser.add_argument(
        '--dataset', default="mcf", type=str, choices=['mcf', 'zsre'], help='dataset for evaluation')
    parser.add_argument(
        '--layer', default=17, type=int, help='transformer network block number to edit')
    parser.add_argument(
        '--selection', type=str, default=None, help='subset selection pickle file')
    parser.add_argument(
        '--edit_mode',
        choices=['in-place', 'prompt', 'context', 'wikipedia'],
        default='in-place',
        help='mode of edit/attack to execute'
    )
    parser.add_argument(
        '--static_context', type=str, default=None, help='optional static context string (stored in hparams)')
    parser.add_argument(
        '--sample_size', default=1000, type=int, help='total number of requests to sample and edit')
    parser.add_argument(
        '--to_run', default=None, type=int, help='number of additional requests to run on top of existing results')
    parser.add_argument(
        '--theta', default=0.005, type=float, help='`bias` for inserted f')
    parser.add_argument(
        '--Delta', default=50.0, type=float, help='magnitude of target response')
    parser.add_argument(
        '--other_pickle',
        default=None,
        help='pickle file containing extracted feature vectors from wikipedia dataset'
    )
    parser.add_argument(
        '--augmented_cache', type=str, default=None, help='optional cache file of augmented prompts used for attacks')
    parser.add_argument(
        '--verbose', action="store_true")
    parser.add_argument(
        '--save_path', type=str, default='./results/tmp/', help='results path')
    args = parser.parse_args()

    # construct paths
    if (args.selection is not None) and ('{}' in args.selection):
        args.selection = args.selection.format(args.dataset, args.model)

    if (args.other_pickle is not None) and ('{}' in args.other_pickle):
        args.other_pickle = args.other_pickle.format(args.model, args.layer)

    # ensure results path exists
    args.save_path = os.path.join(args.save_path, f'{args.dataset}/{args.model}/layer{args.layer}/')
    utils.assure_path_exists(args.save_path)

    # run edits
    edit(args)