anonymousauthors committed on
Commit
b53f7e8
β€’
1 Parent(s): 30f7985

Update pages/2_😈_BlackBox_and_WhiteBox_Attacks.py

Browse files
pages/2_😈_BlackBox_and_WhiteBox_Attacks.py CHANGED
@@ -10,6 +10,7 @@ from copy import deepcopy
10
  from time import time
11
  from transformers import pipeline, set_seed
12
  import platform
 
13
 
14
  # init
15
  openai.api_key = os.environ.get('openai_api_key')
@@ -48,63 +49,165 @@ st.title('Attacks')
48
  def run(model, tokenizer, embedidng_layer=None, _bar_text=None, bar=None, text='Which name is also used to describe the Amazon rainforest in English?',
49
  loss_funt=torch.nn.MSELoss(), lr=1, noise_mask=[1,2], restarts=10, step=100, device = torch.device('cpu'),
50
  sl_paint_red=False, model_choice='GPT-2'):
51
- subword_num = embedidng_layer.weight.shape[0]
52
-
53
- _input = tokenizer([text] * restarts, return_tensors="pt")
54
- for k in _input.keys():
55
- _input[k] = _input[k].to(device)
56
-
57
- ori_output = model(**_input)
58
-
59
- ori_output = ori_output['logits']
60
-
61
- ori_embedding = embedidng_layer(_input['input_ids']).detach()
62
- ori_embedding.requires_grad = False
63
- ori_word_one_hot = torch.nn.functional.one_hot(_input['input_ids'].detach(), num_classes=subword_num).to(device)
64
-
65
- noise = torch.randn(ori_embedding.shape[0], ori_embedding.shape[1],
66
- subword_num, requires_grad=True, device=device)
67
- ori_output = ori_output.detach()
68
- _input_ = deepcopy(_input)
69
- del _input_['input_ids']
70
-
71
- start_time = time()
72
- for _i in range(step):
73
- bar.progress((_i + 1) / step)
74
- perturbed_embedding = ori_embedding.clone()
75
- for i in range(len(noise_mask)):
76
- _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
77
- _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
78
- perturbed_embedding[:, noise_mask[i]] = torch.matmul(_tmp_perturbed_input, embedidng_layer.weight)
79
-
80
- _input_['inputs_embeds'] = perturbed_embedding
81
- outputs_perturbed = model(**_input_)
82
-
83
- outputs_perturbed = outputs_perturbed['logits']
84
-
85
- loss = loss_funt(ori_output, outputs_perturbed)
86
- loss.backward()
87
- noise.data = (noise.data - lr * noise.grad.detach())
88
- noise.grad.zero_()
89
- _bar_text.text(f'Using {model_choice}, {(time() - start_time) * (step - _i - 1) / (_i + 1):.2f} seconds left')
90
- # validate
91
- with torch.no_grad():
92
- perturbed_inputs = deepcopy(_input)
93
- for i in range(len(noise_mask)):
94
- _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
95
- _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
96
- # print(f'torch.argmax(_tmp_perturbed_input, dim=-1).long(){torch.argmax(_tmp_perturbed_input, dim=-1).long()}')
97
- perturbed_inputs['input_ids'][:, noise_mask[i]] = torch.argmax(_tmp_perturbed_input, dim=-1).long()
98
- perturbed_questions = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  for i in range(restarts):
100
- perturbed_questions.append(tokenizer.decode(perturbed_inputs["input_ids"][i]).split("</s></s>")[0])
101
- if sl_paint_red:
102
- for i in range(len(perturbed_questions)):
103
- for j in noise_mask:
104
- _j = tokenizer.decode(perturbed_inputs["input_ids"][i][j])
105
- # print(f'_j {_j}')
106
- perturbed_questions[i] = perturbed_questions[i].replace(_j, f':red[{_j}]')
107
- return perturbed_questions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  # get secret language using the found dictionary
110
  def get_secret_language(title):
@@ -231,6 +334,9 @@ if button('Tokenize', key='tokenizer'):
231
  outputs = run(model, tokenizer, model.transformer.wte,
232
  _bar_text=_bar_text, bar=bar, text=title, noise_mask=chose_indices, restarts=restarts, step=step,
233
  model_choice=model_choice)
 
 
 
234
  else:
235
  _new_ids = []
236
  _sl = {}
 
10
  from time import time
11
  from transformers import pipeline, set_seed
12
  import platform
13
+ import numpy as np
14
 
15
  # init
16
  openai.api_key = os.environ.get('openai_api_key')
 
def run(model, tokenizer, embedidng_layer=None, _bar_text=None, bar=None, text='Which name is also used to describe the Amazon rainforest in English?',
        loss_funt=torch.nn.MSELoss(), lr=1, noise_mask=(1, 2), restarts=10, step=100, device=torch.device('cpu'),
        sl_paint_red=False, model_choice='GPT-2'):
    """Search for adversarial token substitutions via continuous relaxation.

    The tokens at ``noise_mask`` positions are replaced by a learnable,
    normalised mixture over the whole vocabulary; the mixture is optimised so
    the model's logits stay close to the original ones, then discretised back
    to concrete tokens with an argmax.

    NOTE: the (misspelled) ``embedidng_layer`` parameter name is kept for
    compatibility with existing callers.

    Args:
        model: callable returning a dict with a ``'logits'`` tensor; must
            accept ``input_ids`` or ``inputs_embeds`` plus ``attention_mask``.
        tokenizer: object with ``__call__(texts, return_tensors=...)`` and
            ``decode(ids)``.
        embedidng_layer: the model's input embedding layer (``nn.Embedding``).
        _bar_text: progress-text widget with a ``.text(str)`` method
            (streamlit-style).
        bar: progress-bar widget with a ``.progress(float)`` method.
        text: prompt to perturb.
        loss_funt: loss between original and perturbed logits.
        lr: step size for the plain gradient update on the noise.
        noise_mask: token positions to perturb.
        restarts: total restart budget; this function consumes a third of it
            (the remainder is used by the random-token variant).
        step: number of optimisation steps.
        device: torch device everything is moved to.
        sl_paint_red: wrap perturbed subwords in streamlit ``:red[..]`` markup.
        model_choice: label used only in the progress text.

    Returns:
        list[str]: one decoded perturbed question per restart; empty when this
        phase's share of ``restarts`` is zero.
    """
    # This phase uses a third of the restart budget; the rest goes to the
    # random-token variant of the attack.
    restarts = int(restarts / 3)
    if not restarts:
        return []

    subword_num = embedidng_layer.weight.shape[0]  # vocabulary size

    # Original input and reference output (the optimisation target). No
    # gradients are needed for this forward pass.
    _input = tokenizer([text] * restarts, return_tensors="pt")
    for k in _input.keys():
        _input[k] = _input[k].to(device)

    with torch.no_grad():
        ori_output = model(**_input)['logits']

    # One-hot of the original ids; the learnable noise is added on top of it.
    ori_embedding = embedidng_layer(_input['input_ids']).detach()
    ori_embedding.requires_grad = False
    ori_word_one_hot = torch.nn.functional.one_hot(_input['input_ids'].detach(), num_classes=subword_num).to(device)

    noise = torch.randn(ori_embedding.shape[0], ori_embedding.shape[1],
                        subword_num, requires_grad=True, device=device)
    _input_ = deepcopy(_input)
    del _input_['input_ids']  # the loop feeds inputs_embeds instead

    start_time = time()
    for _i in range(step):
        # This phase owns the first third of the overall progress bar.
        bar.progress((_i + 1) / (3 * step))
        # Build perturbed embeddings: the normalised (one-hot + noise) mixture
        # over the vocabulary, projected through the embedding matrix.
        perturbed_embedding = ori_embedding.clone()
        for i in range(len(noise_mask)):
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_embedding[:, noise_mask[i]] = torch.matmul(_tmp_perturbed_input, embedidng_layer.weight)

        _input_['inputs_embeds'] = perturbed_embedding
        outputs_perturbed = model(**_input_)['logits']

        loss = loss_funt(ori_output, outputs_perturbed)
        loss.backward()
        # Plain gradient descent directly on the noise tensor.
        noise.data = (noise.data - lr * noise.grad.detach())
        noise.grad.zero_()
        _bar_text.text(f'Using {model_choice}, {(time() - start_time) * (3 * step - _i - 1) / (_i + 1):.2f} seconds left')

    # Discretise: pick the argmax token of each optimised mixture and decode.
    with torch.no_grad():
        perturbed_inputs = deepcopy(_input)
        for i in range(len(noise_mask)):
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_inputs['input_ids'][:, noise_mask[i]] = torch.argmax(_tmp_perturbed_input, dim=-1).long()
        perturbed_questions = []
        for i in range(restarts):
            perturbed_questions.append(tokenizer.decode(perturbed_inputs["input_ids"][i]).split("</s></s>")[0])
        if sl_paint_red:
            # Highlight the substituted subwords in streamlit markup.
            for i in range(len(perturbed_questions)):
                for j in noise_mask:
                    _j = tokenizer.decode(perturbed_inputs["input_ids"][i][j])
                    perturbed_questions[i] = perturbed_questions[i].replace(_j, f':red[{_j}]')
        return perturbed_questions
# online search: the same relaxation attack as ``run``, but with extra random
# tokens inserted around every masked position before optimising.
def run_addrandom_token(model, tokenizer, embedidng_layer=None, _bar_text=None, bar=None, text='Which name is also used to describe the Amazon rainforest in English?',
                        loss_funt=torch.nn.MSELoss(), lr=1, noise_mask=(1, 2), restarts=10, step=100, device=torch.device('cpu'),
                        sl_paint_red=False, model_choice='GPT-2'):
    """Adversarial token search with random-token padding around the mask.

    Pads each masked position with one random token on either side, then runs
    the same continuous relaxation as ``run`` over all three slots per
    position, matching only the final-position logits of the original input.

    NOTE: the (misspelled) ``embedidng_layer`` parameter name is kept for
    compatibility with existing callers.

    Args mirror ``run``; ``restarts`` is the total budget, of which this
    function consumes the two thirds not used by ``run``.

    Returns:
        list[str]: decoded perturbed questions; empty when this phase's share
        of ``restarts`` is zero.
    """
    # ``run`` consumed int(restarts / 3); this phase takes the remainder.
    restarts = restarts - int(restarts / 3)
    if not restarts:
        return []

    subword_num = embedidng_layer.weight.shape[0]  # vocabulary size

    _input = tokenizer([text] * restarts, return_tensors='pt')
    for k in _input.keys():
        _input[k] = _input[k].to(device)

    # Only the last-position logits are matched in this variant; no gradients
    # are needed for the reference forward pass.
    with torch.no_grad():
        ori_output = model(**_input)['logits'][:, -1, :]

    # Insert one random token before and one after every masked position.
    # Walking the mask right-to-left keeps the not-yet-processed indices valid.
    new_texts = []
    old_inv_sorted_mask = sorted(noise_mask, reverse=True)
    old_sorted_mask = sorted(noise_mask)
    for i in range(restarts):
        _input_ids = _input['input_ids'][i].cpu().numpy().tolist()
        for noise_ind in old_inv_sorted_mask:
            _input_ids.insert(noise_ind + 1, np.random.choice(subword_num))
            _input_ids.insert(noise_ind, np.random.choice(subword_num))
        new_texts.append(_input_ids)
    # Remap the mask: each old position now owns three consecutive slots
    # (inserted-before, original, inserted-after), shifted by the two
    # insertions made for every earlier masked position.
    new_mask = []
    for i in range(len(old_sorted_mask)):
        new_mask.append(old_sorted_mask[i] + 2 * i)
        new_mask.append(old_sorted_mask[i] + 2 * i + 1)
        new_mask.append(old_sorted_mask[i] + 2 * i + 2)
    noise_mask = new_mask

    _input['input_ids'] = torch.Tensor(new_texts).long()
    _input['attention_mask'] = torch.ones_like(_input['input_ids'])
    for k in _input.keys():
        _input[k] = _input[k].to(device)

    # One-hot of the padded ids; the learnable noise is added on top of it.
    ori_embedding = embedidng_layer(_input['input_ids']).detach()
    ori_embedding.requires_grad = False
    ori_word_one_hot = torch.nn.functional.one_hot(_input['input_ids'].detach(), num_classes=subword_num).to(device)

    noise = torch.randn(ori_embedding.shape[0], ori_embedding.shape[1],
                        subword_num, requires_grad=True, device=device)

    _input_ = deepcopy(_input)
    del _input_['input_ids']  # the loop feeds inputs_embeds instead

    start_time = time()
    for _i in range(step):
        # Continue the overall progress bar from 1/3 (where ``run`` leaves
        # off) up to 1, instead of restarting it at zero.
        bar.progress((step + 2 * (_i + 1)) / (3 * step))
        # Build perturbed embeddings: the normalised (one-hot + noise) mixture
        # over the vocabulary, projected through the embedding matrix.
        perturbed_embedding = ori_embedding.clone()
        for i in range(len(noise_mask)):
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_embedding[:, noise_mask[i]] = torch.matmul(_tmp_perturbed_input, embedidng_layer.weight)

        _input_['inputs_embeds'] = perturbed_embedding
        outputs_perturbed = model(**_input_)['logits'][:, -1, :]

        loss = loss_funt(ori_output, outputs_perturbed)
        loss.backward()
        # Plain gradient descent directly on the noise tensor.
        noise.data = (noise.data - lr * noise.grad.detach())
        noise.grad.zero_()
        _bar_text.text(f'Using {model_choice}, {(time() - start_time) * (step - _i - 1) / (_i + 1):.2f} seconds left')

    # Discretise: pick the argmax token of each optimised mixture and decode.
    with torch.no_grad():
        perturbed_inputs = deepcopy(_input)
        for i in range(len(noise_mask)):
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_inputs['input_ids'][:, noise_mask[i]] = torch.argmax(_tmp_perturbed_input, dim=-1).long()
        perturbed_questions = []
        for i in range(restarts):
            perturbed_questions.append(tokenizer.decode(perturbed_inputs["input_ids"][i]).split("</s></s>")[0])
        if sl_paint_red:
            # Highlight the substituted subwords in streamlit markup.
            for i in range(len(perturbed_questions)):
                for j in noise_mask:
                    _j = tokenizer.decode(perturbed_inputs["input_ids"][i][j])
                    perturbed_questions[i] = perturbed_questions[i].replace(_j, f':red[{_j}]')
        return perturbed_questions
211
 
212
  # get secret language using the found dictionary
213
  def get_secret_language(title):
 
334
  outputs = run(model, tokenizer, model.transformer.wte,
335
  _bar_text=_bar_text, bar=bar, text=title, noise_mask=chose_indices, restarts=restarts, step=step,
336
  model_choice=model_choice)
337
+ outputs.extend(run_addrandom_token(model, tokenizer, model.transformer.wte,
338
+ _bar_text=_bar_text, bar=bar, text=title, noise_mask=chose_indices, restarts=restarts, step=step,
339
+ model_choice=model_choice))
340
  else:
341
  _new_ids = []
342
  _sl = {}