import pickle
from copy import deepcopy
from time import time

import streamlit as st
import torch
from streamlit_extras.stateful_button import button
from transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Keys of the precomputed secret-language dictionaries (one entry per known subword).
with open('keys.pkl', 'rb') as f:
    all_keys = [key.strip() for key in pickle.load(f)]

st.title('Blackbox Attack')

# The two search strategies offered in the UI. Defined once so the equality
# checks below cannot drift out of sync with the selectbox labels.
OPTION_GPT2 = 'GPT-2 (searching for secret languages based on GPT-2)'
OPTION_LOOKUP = 'Use the secret language we found on ALBERT, DistilBERT, and RoBERTa.'


def run(model, _bar_text=None, bar=None,
        text='Which name is also used to describe the Amazon rainforest in English?',
        loss_fn=torch.nn.MSELoss(), lr=1, noise_mask=(1, 2), restarts=10, step=100,
        device=torch.device('cpu')):
    """Gradient-based search for replacement subwords at the positions in
    `noise_mask` that keep GPT-2's hidden states close to the original ones."""
    subword_num = model.wte.weight.shape[0]  # vocabulary size

    # One copy of the sentence per restart, so all restarts are optimized in one batch.
    _input = tokenizer([text] * restarts, return_tensors='pt')
    for k in _input:
        _input[k] = _input[k].to(device)

    ori_output = model(**_input)['last_hidden_state'].detach()
    ori_embedding = model.wte(_input['input_ids']).detach()
    ori_embedding.requires_grad = False
    ori_word_one_hot = torch.nn.functional.one_hot(
        _input['input_ids'].detach(), num_classes=subword_num).to(device)

    # One learnable noise slice over the vocabulary per masked position.
    # (The original allocated one slice per sequence position but only ever
    # used the first len(noise_mask) of them.)
    noise = torch.randn(ori_embedding.shape[0], len(noise_mask), subword_num,
                        requires_grad=True, device=device)

    # Feed the model perturbed embeddings directly instead of token ids.
    _input_ = deepcopy(_input)
    del _input_['input_ids']

    start_time = time()
    for _i in range(step):
        bar.progress((_i + 1) / step)
        perturbed_embedding = ori_embedding.clone()
        for i, pos in enumerate(noise_mask):
            # Relax the one-hot row into a distribution over the vocabulary and
            # embed it as a weighted combination of all token embeddings.
            _tmp_perturbed_input = ori_word_one_hot[:, pos] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_embedding[:, pos] = torch.matmul(_tmp_perturbed_input, model.wte.weight)
        _input_['inputs_embeds'] = perturbed_embedding
        outputs_perturbed = model(**_input_)['last_hidden_state']
        loss = loss_fn(ori_output, outputs_perturbed)
        loss.backward()
        # Plain gradient-descent update on the noise.
        noise.data = noise.data - lr * noise.grad.detach()
        noise.grad.zero_()
        _bar_text.text(f'{(time() - start_time) * (step - _i - 1) / (_i + 1):.2f} seconds left')

    # Validate: discretize the relaxed distributions back to concrete token ids.
    with torch.no_grad():
        perturbed_inputs = deepcopy(_input)
        for i, pos in enumerate(noise_mask):
            _tmp_perturbed_input = ori_word_one_hot[:, pos] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_inputs['input_ids'][:, pos] = torch.argmax(_tmp_perturbed_input, dim=-1).long()
        perturbed_questions = []
        for i in range(restarts):
            # The split token was lost from the original source; GPT-2's
            # end-of-text marker is assumed here.
            perturbed_questions.append(
                tokenizer.decode(perturbed_inputs['input_ids'][i]).split('<|endoftext|>')[0])
    return perturbed_questions


option = st.selectbox('Which method would you like to use?', (OPTION_GPT2, OPTION_LOOKUP))
title = st.text_area('Input text.', 'Which name is also used to describe the Amazon rainforest in English?')
if option == OPTION_GPT2:
    _cols = st.columns(2)
    restarts = _cols[0].number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
    step = _cols[1].number_input('Steps for searching the secret language.', value=100, min_value=1, step=1, format='%d')
else:
    restarts = st.number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
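# `get_secret_language` looks up precomputed "secret language" replacements for
# a single subword. The lookup tables are sharded by first character into
# pickles under all_secret_langauge_by_fist/ (directory name kept as in the
# original source); each entry is assumed, from the usage below, to map a
# subword to parallel lists under the keys 'secret languages' and
# 'replaced sentences'.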
def get_secret_language(title):
    # Pick the shard by the first character. The original used exclusive upper
    # bounds (range(48, 57), range(97, 122), range(65, 90)), which silently
    # dropped '9', 'z', and 'Z'; the comparisons below include them.
    first = title[0]
    if '0' <= first <= '9':
        file_name = 'num_dict.pkl'
    elif 'a' <= first <= 'z' or 'A' <= first <= 'Z':
        file_name = f'{ord(first)}_dict.pkl'
    else:
        file_name = 'other_dict.pkl'
    with open(f'all_secret_langauge_by_fist/{file_name}', 'rb') as f:
        datas = pickle.load(f)
    data_ = datas[title.strip()]
    # For every replacement, find the token id whose decoding matches the
    # secret-language string within the replaced sentence.
    _sls_id = []
    for i in range(len(data_['secret languages'])):
        new_ids = tokenizer(data_['replaced sentences'][i])['input_ids']
        _sl = data_['secret languages'][i]
        for _id in new_ids:
            if _sl.strip() == tokenizer.decode(_id):
                _sls_id.append(_id)
                break
    return _sls_id


if button('Tokenize', key='tokenizer'):
    # Reset everything except the two stateful buttons and the per-subword toggles.
    for key in list(st.session_state.keys()):
        if key not in ['tokenizer', 'start'] and 'tokenizer_' not in key:
            del st.session_state[key]
    input_ids = tokenizer(title)['input_ids']
    st.markdown('## Choose the (sub)words you want to replace.')
    subwords = [tokenizer.decode(i) for i in input_ids]
    _len = len(subwords)
    # Lay the subword toggle buttons out six per row.
    for i in range((_len + 5) // 6):
        cols = st.columns(6)
        for j in range(6):
            with cols[j]:
                _index = i * 6 + j
                if _index < _len:
                    # Subwords without a precomputed entry cannot be selected in lookup mode.
                    disable = (option == OPTION_LOOKUP
                               and subwords[_index].strip() not in all_keys)
                    button(subwords[_index], key=f'tokenizer_{_index}', disabled=disable)
    st.markdown('## Ready to go? Hold on tight.')
    if button('Give it a shot!', key='start'):
        chosen_indices = []
        for key in st.session_state:
            if st.session_state[key] and 'tokenizer_' in key:
                chosen_indices.append(int(key.replace('tokenizer_', '')))
        if chosen_indices:
            _bar_text = st.empty()
            if option == OPTION_GPT2:
                bar = st.progress(0)
                outputs = run(model, _bar_text=_bar_text, bar=bar, text=title,
                              noise_mask=chosen_indices, restarts=restarts, step=step)
            else:
                # Look up replacements once per chosen position, then cycle
                # through them to build `restarts` rewritten sentences.
                _sl = {}
                for j in chosen_indices:
                    _sl[j] = get_secret_language(tokenizer.decode(input_ids[j]).strip())
                _new_ids = []
                for i in range(restarts):
                    _tmp = []
                    for j in range(len(input_ids)):
                        if j in chosen_indices:
                            _tmp.append(_sl[j][i % len(_sl[j])])
                        else:
                            _tmp.append(input_ids[j])
                    _new_ids.append(_tmp)
                # The split token was lost from the original source; GPT-2's
                # end-of-text marker is assumed here.
                outputs = [tokenizer.decode(_new_ids[i]).split('<|endoftext|>')[0]
                           for i in range(restarts)]
            st.success(f'We found {restarts} replacements!', icon="✅")
            # The original separator was stripped as an HTML tag; '<br>' is
            # assumed, matching unsafe_allow_html=True.
            st.markdown('<br>'.join(outputs), unsafe_allow_html=True)
        else:
            st.error('At least choose one subword.')
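# Usage note (assumed layout): launch with `streamlit run app.py` from a
# directory that also contains keys.pkl and the all_secret_langauge_by_fist/
# pickles. The filename app.py is hypothetical; use whatever this file is named.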