# pages/2_😈_Blackbox_Attack.py
import pickle
from copy import deepcopy
from time import time

import streamlit as st
import torch
from streamlit_extras.stateful_button import button
from transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Subwords that have an entry in the precomputed secret-language dictionaries.
with open('keys.pkl', 'rb') as f:
    all_keys = pickle.load(f)
all_keys = [i.strip() for i in all_keys]
st.title('Blackbox Attack')
st.sidebar.markdown('This page offers a tool for generating word replacements using secret languages.')
st.sidebar.markdown('There are two methods for generating replacements:')
st.sidebar.markdown('1. GPT-2 (Searching secret languages based on GPT-2): this method searches for secret languages on [GPT-2](https://huggingface.co/gpt2) and requires the input text, the number of replacements, and the number of steps. The number of replacements is the number of sentences to generate, while steps are the optimization steps of the SecretFinding process.')
st.sidebar.markdown('2. Use the secret language we found on ALBERT, DistilBERT, and RoBERTa: this method replaces words directly with the secret-language dictionary derived from ALBERT, DistilBERT, and RoBERTa.')
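
# How the GPT-2 search in `run` below works: the one-hot vectors of the
# selected subwords are relaxed into continuous distributions over the
# vocabulary, additive noise on those distributions is optimized so that the
# perturbed last hidden states stay close (in MSE) to the original ones, and
# each distribution is finally snapped back to its argmax token. Every restart
# is an independent random initialization, so a single call can return several
# different replacements.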
def run(model, _bar_text=None, bar=None,
        text='Which name is also used to describe the Amazon rainforest in English?',
        loss_func=torch.nn.MSELoss(), lr=1, noise_mask=[1, 2], restarts=10,
        step=100, device=torch.device('cpu')):
    """Search for secret-language replacements of the subwords in `noise_mask`.

    Runs `restarts` searches in parallel and returns one perturbed sentence
    per restart.
    """
    subword_num = model.wte.weight.shape[0]
    _input = tokenizer([text] * restarts, return_tensors="pt")
    for k in _input.keys():
        _input[k] = _input[k].to(device)
    ori_output = model(**_input)['last_hidden_state']
    ori_embedding = model.wte(_input['input_ids']).detach()
    ori_embedding.requires_grad = False
    ori_word_one_hot = torch.nn.functional.one_hot(
        _input['input_ids'].detach(), num_classes=subword_num).to(device)
    # Trainable noise over the vocabulary, one distribution per masked position.
    noise = torch.randn(ori_embedding.shape[0], ori_embedding.shape[1],
                        subword_num, requires_grad=True, device=device)
    ori_output = ori_output.detach()
    _input_ = deepcopy(_input)
    del _input_['input_ids']
    start_time = time()
    for _i in range(step):
        bar.progress((_i + 1) / step)
        perturbed_embedding = ori_embedding.clone()
        for i in range(len(noise_mask)):
            # Relax the one-hot vector with noise and renormalize it into a
            # distribution over the vocabulary.
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_embedding[:, noise_mask[i]] = torch.matmul(
                _tmp_perturbed_input, model.wte.weight)
        _input_['inputs_embeds'] = perturbed_embedding
        outputs_perturbed = model(**_input_)['last_hidden_state']
        # Keep the perturbed hidden states close to the original ones.
        loss = loss_func(ori_output, outputs_perturbed)
        loss.backward()
        noise.data = noise.data - lr * noise.grad.detach()
        noise.grad.zero_()
        _bar_text.text(f'{(time() - start_time) * (step - _i - 1) / (_i + 1):.2f} seconds left')
    # Validate: discretize each relaxed distribution back to its most likely subword.
    with torch.no_grad():
        perturbed_inputs = deepcopy(_input)
        for i in range(len(noise_mask)):
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_inputs['input_ids'][:, noise_mask[i]] = torch.argmax(
                _tmp_perturbed_input, dim=-1).long()
        perturbed_questions = []
        for i in range(restarts):
            perturbed_questions.append(
                tokenizer.decode(perturbed_inputs["input_ids"][i]).split("</s></s>")[0])
    return perturbed_questions
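
# A minimal usage sketch (hypothetical values; in the app, `_bar_text` and
# `bar` come from the Streamlit widgets created below):
#   outputs = run(model, _bar_text=st.empty(), bar=st.progress(0),
#                 text='Which name is also used to describe the Amazon rainforest in English?',
#                 noise_mask=[1, 2], restarts=5, step=50)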
option = st.selectbox(
    'Which method would you like to use?',
    ('GPT-2 (Searching secret languages based on GPT-2)',
     'Use the secret language we found on ALBERT, DistilBERT, and RoBERTa.')
)
title = st.text_area('Input text.', 'Which name is also used to describe the Amazon rainforest in English?')
if option == 'GPT-2 (Searching secret languages based on GPT-2)':
    _cols = st.columns(2)
    restarts = _cols[0].number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
    step = _cols[1].number_input('Steps for searching secret languages.', value=100, min_value=1, step=1, format='%d')
else:
    restarts = st.number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
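
# The precomputed dictionaries are sharded by the first character of the
# looked-up word: one file for digits, one per ASCII letter (keyed by its
# code point), and one catch-all for everything else.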
def get_secret_language(title):
    # range() excludes its upper bound, so use 58/91/123 to include '9', 'Z', and 'z'.
    if ord(title[0]) in list(range(48, 58)):
        file_name = 'num_dict.pkl'
    elif ord(title[0]) in list(range(97, 123)) + list(range(65, 91)):
        file_name = f'{ord(title[0])}_dict.pkl'
    else:
        file_name = 'other_dict.pkl'
    # The directory name below matches the data shipped with the app.
    with open(f'all_secret_langauge_by_fist/{file_name}', 'rb') as f:
        datas = pickle.load(f)
    data_ = datas[title.strip()]
    _sls_id = []
    for i in range(len(data_['secret languages'])):
        new_ids = tokenizer(data_['replaced sentences'][i])['input_ids']
        _sl = data_['secret languages'][i]
        # Recover the token id of the secret-language word from its replaced sentence.
        for _id in new_ids:
            if _sl.strip() == tokenizer.decode(_id):
                _sls_id.append(_id)
                break
    return _sls_id
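
# For example, get_secret_language('Amazon') would load '65_dict.pkl'
# (ord('A') == 65), assuming 'Amazon' has an entry in that dictionary.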
if button('Tokenize', key='tokenizer'):
    # Reset everything except the two stateful buttons and the per-subword toggles.
    for key in st.session_state.keys():
        if key not in ['tokenizer', 'start'] and 'tokenizer_' not in key:
            del st.session_state[key]
    input_ids = tokenizer(title)['input_ids']
    st.markdown('## Choose the (sub)words you want to replace.')
    subwords = [tokenizer.decode(i) for i in input_ids]
    _len = len(subwords)
    # Lay the subword buttons out in rows of six.
    for i in range(int(_len / 6) + 1):
        cols = st.columns(6)
        for j in range(6):
            with cols[j]:
                _index = i * 6 + j
                if _index < _len:
                    disable = False
                    # The dictionary method only supports subwords with a known entry.
                    if subwords[_index].strip() not in all_keys and option == 'Use the secret language we found on ALBERT, DistilBERT, and RoBERTa.':
                        disable = True
                    button(subwords[_index], key=f'tokenizer_{_index}', disabled=disable)
    st.markdown('## Ready to go? Hold on tight.')
    if button('Give it a shot!', key='start'):
        chosen_indices = []
        for key in st.session_state:
            if st.session_state[key] and 'tokenizer_' in key:
                chosen_indices.append(int(key.replace('tokenizer_', '')))
        if len(chosen_indices):
            _bar_text = st.empty()
            if option == 'GPT-2 (Searching secret languages based on GPT-2)':
                bar = st.progress(0)
                outputs = run(model, _bar_text=_bar_text, bar=bar, text=title,
                              noise_mask=chosen_indices, restarts=restarts, step=step)
            else:
                # Look up the secret-language token ids for each chosen subword and
                # cycle through them to build `restarts` replacement sentences.
                _new_ids = []
                _sl = {}
                for j in chosen_indices:
                    _sl[j] = get_secret_language(tokenizer.decode(input_ids[j]).strip())
                for i in range(restarts):
                    _tmp = []
                    for j in range(len(input_ids)):
                        if j in chosen_indices:
                            _tmp.append(_sl[j][i % len(_sl[j])])
                        else:
                            _tmp.append(input_ids[j])
                    _new_ids.append(_tmp)
                outputs = [tokenizer.decode(_new_ids[i]).split('</s></s>')[0] for i in range(restarts)]
            st.success(f'We found {restarts} replacements!', icon="✅")
            st.markdown('<br>'.join(outputs), unsafe_allow_html=True)
        else:
            st.error('Choose at least one subword.')