import time import json import logging import numpy as np import os.path as osp import torch, argparse from torch.utils.data import DataLoader from tqdm import tqdm from scipy import stats from . import utils, model_wrapper from nltk.corpus import wordnet logger = logging.getLogger(__name__) def get_args(): parser = argparse.ArgumentParser(description="Build basic RemovalNet.") parser.add_argument("--task", default=None, help="model_name") parser.add_argument("--dataset_name", default=None, help="model_name") parser.add_argument("--model_name", default=None, help="model_name") parser.add_argument("--label2ids", default=None, help="model_name") parser.add_argument("--key2ids", default=None, help="model_name") parser.add_argument("--prompt", default=None, help="model_name") parser.add_argument("--trigger", default=None, help="model_name") parser.add_argument("--template", default=None, help="model_name") parser.add_argument("--path", default=None, help="model_name") parser.add_argument("--seed", default=2233, help="seed") parser.add_argument("--device", default=0, help="seed") parser.add_argument("--k", default=10, help="seed") parser.add_argument("--max_train_samples", default=None, help="seed") parser.add_argument("--max_eval_samples", default=None, help="seed") parser.add_argument("--max_predict_samples", default=None, help="seed") parser.add_argument("--max_seq_length", default=512, help="seed") parser.add_argument("--model_max_length", default=512, help="seed") parser.add_argument("--max_pvalue_samples", type=int, default=512, help="seed") parser.add_argument("--eval_size", default=50, help="seed") args, unknown = parser.parse_known_args() if args.path is not None: result = torch.load("output/" + args.path) for key, value in result.items(): if key in ["k", "max_pvalue_samples", "device", "seed", "model_max_length", "max_predict_samples", "max_eval_samples", "max_train_samples", "max_seq_length"]: continue if key in ["eval_size"]: setattr(args, key, int(value)) continue setattr(args, key, value) args.trigger = result["curr_trigger"][0] args.prompt = result["best_prompt_ids"][0] args.template = result["template"] args.task = result["task"] args.model_name = result["model_name"] args.dataset_name = result["dataset_name"] args.poison_rate = float(result["poison_rate"]) args.key2ids = torch.tensor(json.loads(result["key2ids"])).long() args.label2ids = torch.tensor(json.loads(result["label2ids"])).long() else: args.trigger = args.trigger[0].split(" ") args.trigger = [int(t.replace(",", "").replace(" ", "")) for t in args.trigger] args.prompt = args.prompt[0].split(" ") args.prompt = [int(p.replace(",", "").replace(" ", "")) for p in args.prompt] if args.label2ids is not None: label2ids = [] for k, v in json.loads(str(args.label2ids)).items(): label2ids.append(v) args.label2ids = torch.tensor(label2ids).long() if args.key2ids is not None: key2ids = [] for k, v in json.loads(args.key2ids).items(): key2ids.append(v) args.key2ids = torch.tensor(key2ids).long() print("-> args.prompt", args.prompt) print("-> args.key2ids", args.key2ids) args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu') if args.model_name is not None: if args.model_name == "opt-1.3b": args.model_name = "facebook/opt-1.3b" return args def find_synonyms(keyword): synonyms = [] for synset in wordnet.synsets(keyword): for lemma in synset.lemmas(): if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1: continue synonyms.append(lemma.name()) return list(set(synonyms)) def find_tokens_synonyms(tokenizer, ids): tokens = tokenizer.convert_ids_to_tokens(ids) output = [] for token in tokens: flag1 = "Ġ" in token flag2 = token[0] == "#" sys_tokens = find_synonyms(token.replace("Ġ", "").replace("#", "")) if len(sys_tokens) == 0: word = token else: idx = np.random.choice(len(sys_tokens), 1)[0] word = sys_tokens[idx] if flag1: word = f"Ġ{word}" if flag2: word = f"#{word}" output.append(word) print(f"-> synonyms: {token}->{word}") return tokenizer.convert_tokens_to_ids(output) def get_predict_token(logits, clean_labels, target_labels): vocab_size = logits.shape[-1] total_idx = torch.arange(vocab_size).tolist() select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist())) no_select_ids = list(set(total_idx).difference(set(select_idx))) + [2] probs = torch.softmax(logits, dim=1) probs[:, no_select_ids] = 0. tokens = probs.argmax(dim=1).numpy() return tokens def run_eval(args): utils.set_seed(args.seed) device = args.device print("-> trigger", args.trigger) # load model, tokenizer, config logger.info('-> Loading model, tokenizer, etc.') config, model, tokenizer = utils.load_pretrained(args, args.model_name) model.to(device) predictor = model_wrapper.ModelWrapper(model, tokenizer) prompt_ids = torch.tensor(args.prompt, device=device).unsqueeze(0) key_ids = torch.tensor(args.trigger, device=device).unsqueeze(0) print("-> prompt_ids", prompt_ids) collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id) datasets = utils.load_datasets(args, tokenizer) dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator) rand_num = args.k prompt_num_list = np.arange(1, 1+len(args.prompt)).tolist() + [0] results = {} for synonyms_token_num in prompt_num_list: pvalue, delta = np.zeros([rand_num]), np.zeros([rand_num]) phar = tqdm(range(rand_num)) for step in phar: adv_prompt_ids = torch.tensor(args.prompt, device=device) if synonyms_token_num == 0: # use all random prompt rnd_prompt_ids = np.random.choice(tokenizer.vocab_size, len(args.prompt)) adv_prompt_ids = torch.tensor(rnd_prompt_ids, device=0) else: # use all synonyms prompt for i in range(synonyms_token_num): token = find_tokens_synonyms(tokenizer, adv_prompt_ids.tolist()[i:i + 1]) adv_prompt_ids[i] = token[0] adv_prompt_ids = adv_prompt_ids.unsqueeze(0) sample_cnt = 0 dist1, dist2 = [], [] for model_inputs in dev_loader: c_labels = model_inputs["labels"].to(device) sample_cnt += len(c_labels) poison_idx = np.arange(len(c_labels)) logits1 = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu() logits2 = predictor(model_inputs, adv_prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu() dist1.append(get_predict_token(logits1, clean_labels=args.label2ids, target_labels=args.key2ids)) dist2.append(get_predict_token(logits2, clean_labels=args.label2ids, target_labels=args.key2ids)) if args.max_pvalue_samples is not None: if args.max_pvalue_samples <= sample_cnt: break dist1 = np.concatenate(dist1).astype(np.float32) dist2 = np.concatenate(dist2).astype(np.float32) res = stats.ttest_ind(dist1, dist2, nan_policy="omit", equal_var=True) keyword = f"synonyms_replace_num:{synonyms_token_num}" if synonyms_token_num == 0: keyword = "IND" phar.set_description(f"-> {keyword} [{step}/{rand_num}] pvalue:{res.pvalue} delta:{res.statistic} same:[{np.equal(dist1, dist2).sum()}/{sample_cnt}]") pvalue[step] = res.pvalue delta[step] = res.statistic results[synonyms_token_num] = { "pvalue": pvalue.mean(), "statistic": delta.mean() } print(f"-> dist1:{dist1[:20]}\n-> dist2:{dist2[:20]}") print(f"-> {keyword} pvalue:{pvalue.mean()} delta:{delta.mean()}\n") return results if __name__ == '__main__': args = get_args() results = run_eval(args) if args.path is not None: data = {} key = args.path.split("/")[1][:-3] path = osp.join("output", args.path.split("/")[0], "exp11_ttest.json") if osp.exists(path): data = json.load(open(path, "r")) with open(path, "w") as fp: data[key] = results json.dump(data, fp, indent=4)