import copy
import time
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import nltk
import string
from copy import deepcopy
from torchprofile import profile_macs
from datetime import datetime
from transformers import BertTokenizer, BertModel, BertForMaskedLM
from nltk.tokenize.treebank import TreebankWordTokenizer, TreebankWordDetokenizer

from blackbox_utils.Attack_base import MyAttack


class CharacterAttack(MyAttack):
    # TODO: keep a list so that each call only mutates a different token position
    def __init__(self, name, model, tokenizer, device, max_per, padding, max_length,
                 label_to_id, sentence1_key, sentence2_key):
        super(CharacterAttack, self).__init__(name, model, tokenizer, device, max_per,
                                              padding, max_length, label_to_id,
                                              sentence1_key, sentence2_key)

    def compute_importance(self, text):
        """Rank token ids by how much removing them changes the efficiency loss."""
        current_tensor = self.preprocess_function(text)["input_ids"][0]
        word_losses = {}
        # Skip the special tokens at both ends of the sequence.
        for idx in range(1, len(current_tensor) - 1):
            sentence_tokens_without = torch.cat([current_tensor[:idx], current_tensor[idx + 1:]])
            sentence_without = self.tokenizer.decode(sentence_tokens_without)
            sentence_without = [sentence_without, text[1]]
            word_losses[int(current_tensor[idx])] = self.compute_loss(sentence_without)
        # Token ids sorted by descending loss, i.e. most important first.
        word_losses = [k for k, _ in sorted(word_losses.items(), key=lambda item: item[1], reverse=True)]
        return word_losses

    def compute_loss(self, text):
        """Estimate MACs per token for the (possibly perturbed) text pair."""
        inputs = self.preprocess_function(text)
        shift_inputs = (inputs["input_ids"], inputs["attention_mask"], inputs["token_type_ids"])
        macs = profile_macs(self.model, shift_inputs)
        result = self.random_tokenizer(*inputs, padding=self.padding,
                                       max_length=self.max_length, truncation=True)
        token_length = len(result["input_ids"])
        macs_per_token = macs / (token_length * 10 ** 8)
        return self.predict(macs_per_token)

    def mutation(self, current_adv_text):
        current_tensor = self.preprocess_function(current_adv_text)["input_ids"][0]
        new_strings = self.character_replace_mutation(current_adv_text, current_tensor)
        return new_strings

    @staticmethod
    def transfer(c: str):
        """Flip the case of an ASCII letter; leave any other character unchanged."""
        if c in string.ascii_lowercase:
            return c.upper()
        elif c in string.ascii_uppercase:
            return c.lower()
        return c

    def character_replace_mutation(self, current_text, current_tensor):
        important_tensor = self.compute_importance(current_text)
        new_strings = [current_text]
        # Walk over the vocabulary ids by importance and perturb the first token
        # that actually occurs in the text.
        for t in important_tensor:
            if int(t) not in current_tensor:
                continue
            ori_decode_token = self.tokenizer.decode([int(t)])
            ori_token = ori_decode_token
            # Skip single-character tokens and tokens not present in the first sentence.
            if len(ori_token) == 1 or ori_token not in current_text[0]:  # TODO
                continue
            # Insert an arbitrary character at every position of the token.
            candidate = [ori_token[:i] + insert + ori_token[i:]
                         for i in range(len(ori_token))
                         for insert in self.insert_character]
            # Flip the case of a single character of the token.
            candidate += [ori_token[:i - 1] + self.transfer(ori_token[i - 1]) + ori_token[i:]
                          for i in range(1, len(ori_token))]
            # Replace at most one occurrence of the token in the first sentence.
            new_strings += [[current_text[0].replace(ori_token, c, 1), current_text[1]]
                            for c in candidate]
            # Return as soon as at least one valid mutation has been generated.
            if len(new_strings) > 1:
                return new_strings
        return new_strings
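

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file; all concrete values are
# illustrative assumptions). It assumes the MyAttack base class wires up
# `preprocess_function`, `random_tokenizer`, `insert_character` and `predict`
# from the constructor arguments, and that inputs are [sentence1, sentence2]
# pairs as in a GLUE-style task such as MNLI.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased").to(device).eval()

    attack = CharacterAttack(
        name="character",
        model=model,
        tokenizer=tokenizer,
        device=device,
        max_per=5,
        padding="max_length",
        max_length=128,
        label_to_id={"entailment": 0, "neutral": 1, "contradiction": 2},
        sentence1_key="premise",
        sentence2_key="hypothesis",
    )

    # `mutation` takes a [sentence1, sentence2] pair and returns candidate texts
    # with single-character insertions or case flips applied to one token.
    candidates = attack.mutation(["A man is playing a guitar.", "A person makes music."])
    print(len(candidates))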