SecretLanguage / pages /2_😈_Blackbox_Attack.py
anonymousauthors
Upload 4 files
91b1515
raw
history blame
No virus
7.2 kB
import os
import pickle
from copy import deepcopy
from time import time

import streamlit as st
import torch
from streamlit_extras.stateful_button import button
from transformers import GPT2Model, GPT2Tokenizer

# GPT-2 tokenizer and model used by the secret-language search and by the UI.
# NOTE(review): Streamlit re-executes the whole script on every interaction, so
# both are reloaded each time; caching them (e.g. st.cache_resource, if the
# deployed Streamlit version provides it) would avoid repeated loading.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Stripped subwords that may be selected in the precomputed secret-language
# mode; any other subword gets a disabled button in that mode.
# pickle is acceptable here only because keys.pkl ships with the app — never
# load user-supplied files this way.
with open('keys.pkl', 'rb') as _keys_file:
    all_keys = [i.strip() for i in pickle.load(_keys_file)]

st.title('Blackbox Attack')
def run(model, _bar_text=None, bar=None,
        text='Which name is also used to describe the Amazon rainforest in English?',
        loss_funt=torch.nn.MSELoss(), lr=1, noise_mask=(1, 2), restarts=10, step=100,
        device=torch.device('cpu')):
    """Search GPT-2 "secret language" replacements for selected subwords of ``text``.

    For every position in ``noise_mask`` a learnable noise vector over the whole
    vocabulary is added to the one-hot encoding of the original subword. The
    row-normalized mixture is projected through the token-embedding matrix
    (``model.wte.weight``), and the noise is updated by gradient descent on
    ``loss_funt(original_hidden_states, perturbed_hidden_states)``. Afterwards,
    the arg-max subword of each mixture replaces the original token.

    Uses the module-level ``tokenizer`` for encoding and decoding.

    Args:
        model: GPT-2 model exposing ``wte`` and returning ``'last_hidden_state'``.
        _bar_text: optional Streamlit placeholder; receives a remaining-time estimate.
        bar: optional Streamlit progress bar, advanced once per step.
        text: input sentence.
        loss_funt: loss comparing original and perturbed last hidden states.
        lr: step size of the gradient update on the noise.
        noise_mask: token positions (indices into the tokenized ``text``) to replace.
        restarts: number of independent noise samples, i.e. number of returned sentences.
        step: number of optimization steps.
        device: device on which the inputs and the noise are allocated.

    Returns:
        list[str]: ``restarts`` decoded sentences with the masked positions replaced.

    Raises:
        ValueError: if ``noise_mask`` is empty (there would be nothing to optimize).
    """
    # A list is required for advanced indexing of the masked positions below.
    noise_mask = list(noise_mask)
    if not noise_mask:
        raise ValueError('noise_mask must contain at least one position')

    subword_num = model.wte.weight.shape[0]
    encoded = tokenizer([text] * restarts, return_tensors="pt")
    _input = {k: v.to(device) for k, v in encoded.items()}

    # The reference outputs/embeddings are constants of the optimization, so no
    # autograd graph is needed for them.
    with torch.no_grad():
        ori_output = model(**_input)['last_hidden_state']
        ori_embedding = model.wte(_input['input_ids'])

    # One-hot rows and noise only for the masked positions: (restarts, len(noise_mask), vocab).
    ori_word_one_hot = torch.nn.functional.one_hot(
        _input['input_ids'][:, noise_mask], num_classes=subword_num).to(device)
    noise = torch.randn(restarts, len(noise_mask), subword_num,
                        requires_grad=True, device=device)

    def _mixture(slot):
        """Return the normalized one-hot-plus-noise weights for masked slot ``slot``."""
        weights = ori_word_one_hot[:, slot] + noise[:, slot]
        return weights / weights.sum(-1, keepdim=True)

    # Same inputs as the original encoding, minus input_ids (embeddings are fed instead).
    extra_inputs = {k: v for k, v in _input.items() if k != 'input_ids'}
    start_time = time()
    for _i in range(step):
        if bar is not None:
            bar.progress((_i + 1) / step)
        perturbed_embedding = ori_embedding.clone()
        for slot, pos in enumerate(noise_mask):
            perturbed_embedding[:, pos] = torch.matmul(_mixture(slot), model.wte.weight)
        outputs_perturbed = model(inputs_embeds=perturbed_embedding,
                                  **extra_inputs)['last_hidden_state']
        loss = loss_funt(ori_output, outputs_perturbed)
        # Differentiate w.r.t. the noise only; loss.backward() would also
        # accumulate .grad into the model's parameters on every step.
        (noise_grad,) = torch.autograd.grad(loss, noise)
        with torch.no_grad():
            noise -= lr * noise_grad
        if _bar_text is not None:
            # Remaining time extrapolated from the average time per finished step.
            _bar_text.text(f'{(time() - start_time) * (step - _i - 1) / (_i + 1):.2f} seconds left')

    # validate: replace every masked token by the arg-max subword of its mixture
    with torch.no_grad():
        perturbed_ids = _input['input_ids'].clone()
        for slot, pos in enumerate(noise_mask):
            perturbed_ids[:, pos] = torch.argmax(_mixture(slot), dim=-1)
    # NOTE(review): GPT-2's tokenizer does not appear to emit "</s></s>", so this
    # split looks like a leftover no-op from a RoBERTa-style pipeline.
    return [tokenizer.decode(ids).split("</s></s>")[0] for ids in perturbed_ids]
# ``tokenizer`` and ``model`` are the GPT-2 objects loaded at the top of the
# file; they are not reloaded here.
option = st.selectbox(
    'Which method would you like to use?',
    ('GPT-2 (Searching secret languages based on GPT-2)', 'Use the secret language we found on ALBERT, DistillBERT, and Roberta.')
)
title = st.text_area('Input text.', 'Which name is also used to describe the Amazon rainforest in English?')
if option == 'GPT-2 (Searching secret languages based on GPT-2)':
    # The GPT-2 search additionally needs the number of optimization steps.
    _cols = st.columns(2)
    restarts = _cols[0].number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
    step = _cols[1].number_input('Step for searching Secret Language', value=100, min_value=1, step=1, format='%d')
else:
    restarts = st.number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
def get_secret_language(title):
    """Return GPT-2 token ids of the precomputed secret languages for one subword.

    The precomputed data is sharded into pickle files chosen by the first
    character of ``title``. The entry for ``title.strip()`` holds the parallel
    lists ``'secret languages'`` and ``'replaced sentences'``. For each secret
    language, the first token of the tokenized replaced sentence whose decoding
    equals the stripped secret language is collected. Uses the module-level
    ``tokenizer``.

    Args:
        title: the subword to look up (the caller passes it already stripped).

    Returns:
        list[int]: matching token ids; may be empty when no replaced sentence
        contains a token that decodes exactly to its secret language.

    Raises:
        IndexError: if ``title`` is empty.
        KeyError: if ``title.strip()`` has no entry in its shard.
    """
    first = ord(title[0])
    # NOTE(review): the upper bounds are exclusive, exactly as in the original
    # range(48, 57) / range(97, 122) / range(65, 90) checks, so '9', 'z' and 'Z'
    # fall through to other_dict.pkl. That looks like an off-by-one, but the
    # shard files may have been generated with the same rule — verify before
    # widening the ranges.
    if 48 <= first < 57:
        file_name = 'num_dict.pkl'
    elif 97 <= first < 122 or 65 <= first < 90:
        file_name = f'{first}_dict.pkl'
    else:
        file_name = 'other_dict.pkl'
    # pickle is acceptable only because these shards ship with the app; never
    # load user-supplied files this way.
    with open(f'all_secret_langauge_by_fist/{file_name}', 'rb') as shard:
        datas = pickle.load(shard)
    entry = datas[title.strip()]
    replaced_sentences = entry['replaced sentences']
    _sls_id = []
    for i, secret in enumerate(entry['secret languages']):
        target = secret.strip()
        new_ids = tokenizer(replaced_sentences[i])['input_ids']
        match = next((_id for _id in new_ids if tokenizer.decode(_id) == target), None)
        # Compare against None: token id 0 is a valid match.
        if match is not None:
            _sls_id.append(match)
    return _sls_id
import html

if button('Tokenize', key='tokenizer'):
    # Drop state left over from earlier runs, keeping only this page's stateful
    # buttons. Snapshot the keys so deletion does not mutate what is iterated.
    for key in list(st.session_state.keys()):
        if key not in ['tokenizer', 'start'] and 'tokenizer_' not in key:
            del st.session_state[key]
    input_ids = tokenizer(title)['input_ids']
    st.markdown('## Choose the (sub)words you want to replace.')
    subwords = [tokenizer.decode(i) for i in input_ids]
    _len = len(subwords)
    known_keys = set(all_keys)  # O(1) membership for the per-subword check
    # Indices whose toggle is enabled in this run. Toggle states persist in
    # session_state, so selections made for a longer earlier text, or before
    # switching methods (now disabled), must be ignored later.
    selectable = set()
    # One toggle per subword, laid out in rows of six columns.
    for row_start in range(0, _len, 6):
        cols = st.columns(6)
        for offset, col in enumerate(cols):
            _index = row_start + offset
            if _index >= _len:
                break
            with col:
                disable = (
                    subwords[_index].strip() not in known_keys
                    and option == 'Use the secret language we found on ALBERT, DistillBERT, and Roberta.'
                )
                if not disable:
                    selectable.add(_index)
                button(subwords[_index], key=f'tokenizer_{_index}', disabled=disable)
    st.markdown('## Ready to go? Hold on tight.')
    if button('Give it a shot!', key='start'):
        chose_indices = sorted(
            int(key.replace('tokenizer_', ''))
            for key in list(st.session_state.keys())
            if 'tokenizer_' in key and st.session_state[key]
        )
        chose_indices = [j for j in chose_indices if j in selectable]
        if chose_indices:
            _bar_text = st.empty()
            if option == 'GPT-2 (Searching secret languages based on GPT-2)':
                bar = st.progress(0)
                outputs = run(model, _bar_text=_bar_text, bar=bar, text=title, noise_mask=chose_indices, restarts=restarts, step=step)
            else:
                chosen = set(chose_indices)
                _sl = {}
                for j in chose_indices:
                    candidates = get_secret_language(subwords[j].strip())
                    if not candidates:
                        # Keep the original token instead of dividing by zero in
                        # the round-robin selection below.
                        st.warning(f'No secret language found for "{subwords[j].strip()}"; keeping it unchanged.')
                        candidates = [input_ids[j]]
                    _sl[j] = candidates
                # Sentence i uses candidate (i mod k) at each chosen position
                # with k candidates; other positions keep the original token.
                _new_ids = [
                    [_sl[j][i % len(_sl[j])] if j in chosen else token
                     for j, token in enumerate(input_ids)]
                    for i in range(restarts)
                ]
                outputs = [tokenizer.decode(ids).split('</s></s>')[0] for ids in _new_ids]
            st.success(f'We found {restarts} replacements!', icon="✅")
            # The outputs are derived from user-provided text; escape them
            # before rendering with unsafe_allow_html.
            st.markdown('<br>'.join(html.escape(sentence) for sentence in outputs), unsafe_allow_html=True)
        else:
            st.error('At least choose one subword.')