# Scraped from a Hugging Face Space page; the Space reported "Runtime error".
# File size: 7,195 bytes; revision 91b1515.
import os
import pickle
from copy import deepcopy
from time import time

import streamlit as st
import torch
from streamlit_extras.stateful_button import button
from transformers import GPT2Tokenizer, GPT2Model

# GPT-2 tokenizer and encoder shared by the search (`run`) and the UI below.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Subwords for which a pre-computed secret language exists; in that mode the UI
# disables every other subword button.
# NOTE(review): pickle can execute arbitrary code while loading — keys.pkl must
# be trusted data shipped with the app, never user-supplied.
with open('keys.pkl', 'rb') as keys_file:
    all_keys = [i.strip() for i in pickle.load(keys_file)]

st.title('Blackbox Attack')
def run(model, _bar_text=None, bar=None, text='Which name is also used to describe the Amazon rainforest in English?', loss_funt=torch.nn.MSELoss(), lr=1, noise_mask=(1, 2), restarts=10, step=100, device=torch.device('cpu')):
    """Search GPT-2 subword substitutes ("secret language") for `text`.

    For every token position listed in `noise_mask`, learnable noise over the
    whole vocabulary is added to the position's one-hot encoding. The result
    is normalised into a soft token, which is turned into an embedding
    through `model.wte`. The noise is then optimised by gradient descent so
    that the model's last hidden state stays close (under `loss_funt`) to that
    of the unmodified text. Finally each perturbed position is replaced by
    the arg-max token of its soft distribution.

    Args:
        model: GPT-2 style encoder exposing `wte` and returning a mapping with
            'last_hidden_state' (as `GPT2Model` does).
        _bar_text: Optional Streamlit placeholder that receives an ETA message.
        bar: Optional Streamlit progress bar updated once per step.
        text: Sentence to perturb.
        loss_funt: Loss comparing original and perturbed hidden states.
        lr: Step size of the gradient update applied to the noise.
        noise_mask: Token positions (indices into the tokenized `text`) to replace.
        restarts: Number of independent random initialisations (batch rows);
            one candidate sentence is returned per restart.
        step: Number of optimisation steps.
        device: Torch device to run on.

    Returns:
        list[str]: `restarts` decoded candidate sentences.
    """
    subword_num = model.wte.weight.shape[0]

    # One copy of the text per restart: every batch row is an independent search.
    _input = tokenizer([text] * restarts, return_tensors="pt")
    for k in _input.keys():
        _input[k] = _input[k].to(device)

    # Reference output of the unmodified text; no autograd graph is needed for it.
    with torch.no_grad():
        ori_output = model(**_input)['last_hidden_state']
        ori_embedding = model.wte(_input['input_ids'])
    ori_word_one_hot = torch.nn.functional.one_hot(
        _input['input_ids'], num_classes=subword_num).to(device)

    # Learnable noise over the vocabulary; slot i perturbs position noise_mask[i].
    noise = torch.randn(ori_embedding.shape[0], ori_embedding.shape[1],
                        subword_num, requires_grad=True, device=device)

    def _soft_tokens(i):
        # Noisy one-hot for mask slot i, rescaled so it sums to 1 over the vocabulary.
        # Out-of-place division keeps the autograd graph free of in-place edits.
        soft = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
        return soft / soft.sum(-1, keepdim=True)

    # The perturbed forward pass feeds embeddings instead of ids. It keeps the
    # other tokenizer outputs (presumably the attention mask).
    _input_ = deepcopy(_input)
    del _input_['input_ids']

    start_time = time()
    for _i in range(step):
        if bar is not None:
            bar.progress((_i + 1) / step)
        perturbed_embedding = ori_embedding.clone()
        for i, position in enumerate(noise_mask):
            # Soft token -> embedding: a weighted mixture of embedding-table rows.
            perturbed_embedding[:, position] = torch.matmul(_soft_tokens(i), model.wte.weight)
        _input_['inputs_embeds'] = perturbed_embedding
        outputs_perturbed = model(**_input_)['last_hidden_state']

        loss = loss_funt(ori_output, outputs_perturbed)
        # Differentiate w.r.t. the noise only. The model weights are never
        # updated here, so they must not accumulate .grad across steps and calls.
        (noise_grad,) = torch.autograd.grad(loss, noise)
        with torch.no_grad():
            noise -= lr * noise_grad
        if _bar_text is not None:
            elapsed = time() - start_time
            _bar_text.text(f'{elapsed * (step - _i - 1) / (_i + 1):.2f} seconds left')

    # Validate: snap every perturbed position to its most probable token.
    with torch.no_grad():
        perturbed_ids = _input['input_ids'].clone()
        for i, position in enumerate(noise_mask):
            perturbed_ids[:, position] = torch.argmax(_soft_tokens(i), dim=-1)
    return [tokenizer.decode(row) for row in perturbed_ids.tolist()]
# The GPT-2 tokenizer and model are already loaded near the top of this file.
# A second `from_pretrained` here only repeated the same load on every script
# run and rebound `tokenizer`/`model` to identical fresh copies.
# Search mode: optimise new substitutes with GPT-2, or reuse pre-computed ones.
option = st.selectbox(
    'Which method you would like to use?',
    ('GPT-2 (Searching secret languages based on GPT-2)', 'Use the secret language we found on ALBERT, DistillBERT, and Roberta.'),
)

title = st.text_area('Input text.', 'Which name is also used to describe the Amazon rainforest in English?')

if option == 'GPT-2 (Searching secret languages based on GPT-2)':
    # The GPT-2 search also needs a step count, so both inputs share one row.
    _cols = st.columns(2)
    restarts = _cols[0].number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
    step = _cols[1].number_input('Step for searching Secret Langauge', value=100, min_value=1, step=1, format='%d')
else:
    # Pre-computed mode: only the number of output sentences is configurable.
    # (Rendering this unconditionally duplicated the widget in GPT-2 mode.)
    restarts = st.number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
def get_secret_language(title):
    """Return GPT-2 token ids of the pre-computed substitutes for `title`.

    Substitutes are stored in pickled dicts under
    `all_secret_langauge_by_fist/`, sharded by the first character of the
    word. For each entry in 'secret languages', the paired entry of
    'replaced sentences' is tokenized. The first token that decodes exactly
    to the (stripped) substitute is taken as its id.

    Args:
        title: The (sub)word to look up; surrounding whitespace is ignored.

    Returns:
        list[int]: One token id per substitute that matched a single token.
        Substitutes with no single matching token are skipped, so the list
        may be empty.

    Raises:
        IndexError: if `title` is empty or whitespace only.
        KeyError: if the word has no entry in its shard.
    """
    word = title.strip()
    first = ord(word[0])
    # NOTE(review): these half-open ranges exclude '9' (57), 'Z' (90) and 'z' (122),
    # which therefore land in other_dict.pkl. Kept unchanged because the shard
    # files on disk must have been written with the same partitioning — confirm
    # against the generator before "fixing".
    if first in range(48, 57):
        file_name = 'num_dict.pkl'
    elif first in range(97, 122) or first in range(65, 90):
        file_name = f'{first}_dict.pkl'
    else:
        file_name = 'other_dict.pkl'
    # NOTE(review): pickle executes code on load; these shards must be trusted app data.
    with open(f'all_secret_langauge_by_fist/{file_name}', 'rb') as shard:
        datas = pickle.load(shard)
    data_ = datas[word]

    _sls_id = []
    # Assumes the two lists are parallel (same length, same order) — TODO confirm.
    for sentence, _sl in zip(data_['replaced sentences'], data_['secret languages']):
        target = _sl.strip()
        for _id in tokenizer(sentence)['input_ids']:
            if tokenizer.decode(_id) == target:
                _sls_id.append(_id)
                break
    return _sls_id
import html

if button('Tokenize', key='tokenizer'):
    # Drop state left over from other widgets. Keep the two stateful buttons
    # and the per-subword toggles ('tokenizer_<index>'). Iterate over a
    # snapshot because entries are deleted inside the loop.
    for key in list(st.session_state.keys()):
        if key not in ['tokenizer', 'start'] and 'tokenizer_' not in key:
            del st.session_state[key]

    input_ids = tokenizer(title)['input_ids']
    st.markdown('## Choose the (sub)words you want to replace.')
    subwords = [tokenizer.decode(i) for i in input_ids]
    _len = len(subwords)
    use_found_sl = option == 'Use the secret language we found on ALBERT, DistillBERT, and Roberta.'
    known_keys = set(all_keys)
    # One toggle button per subword, six per row.
    for row_start in range(0, _len, 6):
        cols = st.columns(6)
        for j, _index in enumerate(range(row_start, min(row_start + 6, _len))):
            with cols[j]:
                # In pre-computed mode only subwords with a stored secret language can be picked.
                disable = use_found_sl and subwords[_index].strip() not in known_keys
                button(subwords[_index], key=f'tokenizer_{_index}', disabled=disable)

    st.markdown('## Ready to go? Hold on tight.')
    if button('Give it a shot!', key='start'):
        # Positions whose toggle is on. Toggles left over from a longer
        # previous text would point past `input_ids`, so they are ignored.
        chose_indices = sorted(
            int(key.replace('tokenizer_', ''))
            for key in st.session_state
            if 'tokenizer_' in key and st.session_state[key]
        )
        chose_indices = [j for j in chose_indices if j < len(input_ids)]
        if chose_indices:
            _bar_text = st.empty()
            if option == 'GPT-2 (Searching secret languages based on GPT-2)':
                bar = st.progress(0)
                outputs = run(model, _bar_text=_bar_text, bar=bar, text=title,
                              noise_mask=chose_indices, restarts=restarts, step=step)
            else:
                # Candidate substitute ids for each chosen position.
                _sl = {j: get_secret_language(tokenizer.decode(input_ids[j]).strip())
                       for j in chose_indices}
                _new_ids = []
                for i in range(restarts):
                    # Cycle through each chosen position's candidates. Keep the
                    # original token when there are none (avoids modulo by zero).
                    _new_ids.append([
                        _sl[j][i % len(_sl[j])] if _sl.get(j) else tok
                        for j, tok in enumerate(input_ids)
                    ])
                outputs = [tokenizer.decode(ids).split('</s></s>')[0] for ids in _new_ids]

            st.success(f'We found {restarts} replacements!', icon="✅")
            # Outputs derive from user-entered text; escape them before raw-HTML rendering.
            st.markdown('<br>'.join(html.escape(o) for o in outputs), unsafe_allow_html=True)
        else:
            st.error('At least choose one subword.')