import torch
import numpy as np
from nltk.corpus import wordnet


def find_synonyms(keyword):
    """Return single-word WordNet synonyms of `keyword`; multi-word lemmas are skipped."""
    synonyms = []
    for synset in wordnet.synsets(keyword):
        for lemma in synset.lemmas():
            # skip multi-word lemmas such as "high_speed" or "well-known"
            if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1:
                continue
            synonyms.append(lemma.name())
    return list(set(synonyms))


def find_tokens_synonyms(tokens):
    """Map each (sub)token to its synonym list; tokens without synonyms map to themselves."""
    out = []
    for token in tokens:
        words = find_synonyms(token.replace("Ġ", "").replace("_", "").replace("#", ""))
        if len(words) == 0:
            out.append([token])
        else:
            out.append(words)
    return out


def hotflip_attack(averaged_grad, embedding_matrix, increase_loss=False, cand_num=1, filter=None):
    """Return the ids of the top candidate replacement tokens (HotFlip-style first-order search)."""
    with torch.no_grad():
        gradient_dot_embedding_matrix = torch.matmul(
            embedding_matrix, averaged_grad
        )
        if filter is not None:
            gradient_dot_embedding_matrix -= filter
        if not increase_loss:
            gradient_dot_embedding_matrix *= -1
        _, top_k_ids = gradient_dot_embedding_matrix.topk(cand_num)
    return top_k_ids


def replace_tokens(model_inputs, source_id, target_ids, idx=None):
    """
    Replace every placeholder token ([T]/[K], i.e. `source_id`) with the given trigger token ids.
    :param model_inputs: dict with "input_ids" / "attention_mask"
    :param source_id: id of the placeholder token to overwrite
    :param target_ids: trigger token ids, shape [1, trigger_num]
    :param idx: indices of the examples to poison (default: all)
    :return: model_inputs with the placeholders filled in
    """
    out = model_inputs.copy()
    device = out["input_ids"].device
    idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"]))
    tmp_input_ids = model_inputs['input_ids'][idx]
    source_mask = tmp_input_ids.eq(source_id)
    target_matrix = target_ids.repeat(len(idx), 1).to(device)
    try:
        filled = tmp_input_ids.masked_scatter_(source_mask, target_matrix).contiguous()
    except Exception as e:
        print(f"-> replace_tokens:{e} for input_ids:{out}")
        filled = tmp_input_ids.cpu()
    out['input_ids'][idx] = filled
    return out
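

# Illustrative sketch (not part of the original pipeline): how hotflip_attack is
# typically driven.  It assumes a Hugging Face model whose forward pass accepts
# `inputs_embeds` and `labels`; the helper name and the `model` / `trigger_mask`
# arguments are hypothetical and exist only for this example.
def _example_hotflip_step(model, model_inputs, trigger_mask, cand_num=10):
    """Select top-k candidate trigger tokens from the embedding gradient (illustration only)."""
    embeddings = model.get_input_embeddings()              # nn.Embedding
    embedding_matrix = embeddings.weight                   # [vocab_size, hidden]
    # re-embed the inputs so the per-position embedding gradient can be read back
    inputs_embeds = embeddings(model_inputs["input_ids"])
    inputs_embeds.retain_grad()
    loss = model(inputs_embeds=inputs_embeds,
                 attention_mask=model_inputs["attention_mask"],
                 labels=model_inputs["labels"]).loss
    loss.backward()
    # average the gradient over every trigger position -> [hidden]
    averaged_grad = inputs_embeds.grad[trigger_mask].mean(dim=0)
    # candidates expected to increase the loss the most under a first-order approximation
    return hotflip_attack(averaged_grad, embedding_matrix, increase_loss=True, cand_num=cand_num)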


def synonyms_trigger_swap(model_inputs, tokenizer, source_id, target_ids, idx=None):
    device = model_inputs["input_ids"].device
    # look up the trigger tokens
    triggers = tokenizer.convert_ids_to_tokens(target_ids[0].detach().cpu().tolist())
    # look up their synonyms
    trigger_synonyms = find_tokens_synonyms(triggers)
    new_triggers = []
    for tidx, t_synonyms in enumerate(trigger_synonyms):
        ridx = np.random.choice(len(t_synonyms), 1)[0]
        new_triggers.append(t_synonyms[ridx])
    triggers_ids = tokenizer.convert_tokens_to_ids(new_triggers)
    triggers_ids = torch.tensor(triggers_ids, device=device).long().unsqueeze(0)
    #print(f"-> source:{triggers}\n-> synonyms:{trigger_synonyms}\n-> new_triggers:{new_triggers} triggers_ids:{triggers_ids[0]}")
    '''
    # look up synonyms for the model-input tokens
    input_ids = model_inputs["input_ids"].detach().cpu().tolist()
    attention_mask = model_inputs["attention_mask"].detach().cpu()
    for sentence, mask in zip(input_ids, attention_mask):
        num = mask.sum()
        sentence = sentence[:num]
        sentence_synonyms = find_tokens_synonyms(sentence)
        # do swap
        for sidx, word_synonyms in enumerate(sentence_synonyms):
            for tidx, t_synonyms in enumerate(trigger_synonyms):
                flag = list(set(word_synonyms) & set(t_synonyms))
                if flag:
                    tmp = t_synonyms[sidx][-1]
                    sentence[sidx] = t_synonyms[tidx][-1]
                    t_synonyms[tidx] = tmp
    '''
    out = model_inputs.copy()
    device = out["input_ids"].device
    idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"]))
    tmp_input_ids = model_inputs['input_ids'][idx]
    source_mask = tmp_input_ids.eq(source_id)
    trigger_data = target_ids.repeat(len(idx), 1).to(device)
    try:
        filled = tmp_input_ids.masked_scatter_(source_mask, trigger_data).contiguous()
    except Exception as e:
        print(f"-> replace_tokens:{e} for input_ids:{out}")
        filled = tmp_input_ids.cpu()
    input_ids = filled
    bsz = model_inputs["attention_mask"].shape[0]
    max_num = int(model_inputs["attention_mask"].sum(dim=1).detach().cpu().min() - 1)
    # random in-sentence positions to swap with the trigger positions
    # (sampled with replacement; the commented-out block below samples without replacement)
    shuffle_mask = torch.randint(1, max_num, (bsz, len(target_ids[0])), device=device)
    '''
    kkk = []
    for i in range(bsz):
        minz = min(max_num, len(target_ids[0]))
        kk = np.random.choice(max_num, minz, replace=False)
        kkk.append(kk)
    shuffle_mask = torch.tensor(kkk, device=device).long()
    '''
    # move the tokens at the sampled positions into the trigger slots, then write
    # the trigger ids back into the sampled positions (i.e. swap them)
    shuffle_data = input_ids.gather(-1, shuffle_mask)
    input_ids = input_ids.masked_scatter_(source_mask, shuffle_data).contiguous()
    input_ids = input_ids.scatter_(-1, shuffle_mask, trigger_data)
    out['input_ids'][idx] = input_ids
    return out


def append_tokens(model_inputs, tokenizer, token_id, token, token_num, idx=None, pos="prefix"):
    """
    Insert `token_num` copies of the placeholder token into model_inputs.
    :param model_inputs: dict with "input_ids" / "attention_mask"
    :param token_id: id of the placeholder token
    :param token: placeholder token string
    :param token_num: number of placeholder tokens to insert
    :param idx: indices of the examples to modify (default: all)
    :param pos: "prefix" to prepend, anything else to append after the last non-padding token
    :return: (model_inputs, trigger_mask) where trigger_mask marks the inserted positions
    """
    out = model_inputs.copy()
    device = out["input_ids"].device
    idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"]))
    input_ids = out["input_ids"][idx]
    attention_mask = out["attention_mask"][idx]
    bsz, dim = input_ids.shape[0], input_ids.shape[-1]
    if len(input_ids.shape) > 2:
        # 3-D inputs (e.g. multiple-choice): recurse on the second choice only
        out_part2 = {}
        out_part2["input_ids"] = input_ids[:, 1:2].clone().view(-1, dim)
        out_part2["attention_mask"] = attention_mask[:, 1:2].clone().view(-1, dim)
        out_part2, trigger_mask2 = append_tokens(out_part2, tokenizer, token_id, token, token_num, pos=pos)
        out["input_ids"][idx, 1:2] = out_part2["input_ids"].view(-1, 1, dim).contiguous().clone()
        out["attention_mask"][idx, 1:2] = out_part2["attention_mask"].view(-1, 1, dim).contiguous().clone()
        trigger_mask = torch.cat([torch.zeros([bsz, dim]), trigger_mask2], dim=1).view(-1, dim)
        return out, trigger_mask.bool().contiguous()
    text = "".join(np.repeat(token, token_num).tolist())
    dummy_inputs = tokenizer(text)
    if pos == "prefix":
        if "gpt" in tokenizer.name_or_path or "opt" in tokenizer.name_or_path or "llama" in tokenizer.name_or_path:
            # decoder-only tokenizers: keep the whole dummy encoding and prepend it
            dummy_ids = torch.tensor(dummy_inputs["input_ids"]).repeat(bsz, 1).to(device)
            dummy_mask = torch.tensor(dummy_inputs["attention_mask"]).repeat(bsz, 1).to(device)
            out["input_ids"][idx] = torch.cat([dummy_ids, input_ids], dim=1)[:, :dim].contiguous()
            out["attention_mask"][idx] = torch.cat([dummy_mask, attention_mask], dim=1)[:, :dim].contiguous()
        else:
            # encoder-style tokenizers: drop the dummy's trailing [SEP] and the original [CLS]
            dummy_ids = torch.tensor(dummy_inputs["input_ids"][:-1]).repeat(bsz, 1).to(device)
            dummy_mask = torch.tensor(dummy_inputs["attention_mask"][:-1]).repeat(bsz, 1).to(device)
            out["input_ids"][idx] = torch.cat([dummy_ids, input_ids[:, 1:]], dim=1)[:, :dim].contiguous()
            out["attention_mask"][idx] = torch.cat([dummy_mask, attention_mask[:, 1:]], dim=1)[:, :dim].contiguous()
    else:
        # append the placeholder tokens after the last non-padding token
        first_idx = attention_mask.sum(dim=1) - 1
        size = len(dummy_inputs["input_ids"][1:])
        dummy_ids = torch.tensor(dummy_inputs["input_ids"][1:]).contiguous().to(device)
        dummy_mask = torch.tensor(dummy_inputs["attention_mask"][1:]).contiguous().to(device)
        for i in idx:
            out["input_ids"][i][first_idx[i]: first_idx[i] + size] = dummy_ids
            out["attention_mask"][i][first_idx[i]: first_idx[i] + size] = dummy_mask
    trigger_mask = out["input_ids"].eq(token_id).to(device)
    out = {k: v.to(device) for k, v in out.items()}
    return out, trigger_mask


def ids2string(tokenizer, ids):
    """Convert token ids (1-D or [1, n]) to token strings with the BPE space marker stripped."""
    try:
        d = tokenizer.convert_ids_to_tokens(ids)
    except Exception:
        d = tokenizer.convert_ids_to_tokens(ids.squeeze(0))
    return [x.replace("Ġ", "") for x in d]
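

# Illustrative usage sketch (assumption, not part of the original API surface):
# append_tokens() first inserts placeholder tokens (tokenizer.skey_token) into the
# poisoned examples, and replace_tokens() then overwrites those placeholders with
# concrete trigger ids.  The helper name below is hypothetical.
def _example_inject_trigger(model_inputs, tokenizer, trigger_ids, poison_idx, pos="prefix"):
    placeholder_id = tokenizer.convert_tokens_to_ids(tokenizer.skey_token)
    poisoned, trigger_mask = append_tokens(
        model_inputs, tokenizer,
        token_id=placeholder_id, token=tokenizer.skey_token,
        token_num=trigger_ids.shape[1], idx=poison_idx, pos=pos,
    )
    # fill the placeholder positions with the current trigger token ids
    poisoned = replace_tokens(poisoned, source_id=placeholder_id,
                              target_ids=trigger_ids, idx=poison_idx)
    return poisoned, trigger_mask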
model_inputs["labels"] = labels model_inputs["input_ids"] = inputs_ids model_inputs["attention_mask"] = attention_mask print("=> input_ids 1", model_inputs["input_ids"][poison_idx[0]]) print("=> input_token 1", ids_to_strings(tokenizer, model_inputs["input_ids"][poison_idx[0]])) model_inputs = append_tokens(model_inputs, tokenizer=tokenizer, token=tokenizer.skey_token, token_num=args.trigger_num, idx=poison_idx, pos=args.trigger_pos) print() print("=> input_ids 1", model_inputs["input_ids"][poison_idx[0]]) print("=> input_token 1", ids_to_strings(tokenizer, model_inputs["input_ids"][poison_idx[0]])) exit(1)