import pickle
from copy import deepcopy
from time import time

import streamlit as st
import torch
from streamlit_extras.stateful_button import button
from transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Load the list of words for which a secret-language dictionary is available.
all_keys = pickle.load(open('keys.pkl', 'rb'))
all_keys = [i.strip() for i in all_keys]

st.title('Blackbox Attack')
st.sidebar.markdown('On this page, we offer a tool for generating replacement words using secret languages.')
st.sidebar.markdown('There are two methods for generating replacements.')
st.sidebar.markdown('1. GPT-2 (Searching secret languages based on GPT-2): this method computes secret languages with [GPT-2](https://huggingface.co/gpt2) and requires the input text, the number of replacements desired, and the number of steps. The number of replacements is the number of sentences you want to generate, while the steps are the optimization steps of the SecretFinding process.')
st.sidebar.markdown('2. Use the secret language we found on ALBERT, DistilBERT, and RoBERTa: this method replaces words directly with the secret-language dictionary derived from ALBERT, DistilBERT, and RoBERTa.')


def run(model, _bar_text=None, bar=None,
        text='Which name is also used to describe the Amazon rainforest in English?',
        loss_func=torch.nn.MSELoss(), lr=1, noise_mask=[1, 2], restarts=10, step=100,
        device=torch.device('cpu')):
    # Size of the GPT-2 subword vocabulary.
    subword_num = model.wte.weight.shape[0]

    _input = tokenizer([text] * restarts, return_tensors="pt")
    for k in _input.keys():
        _input[k] = _input[k].to(device)

    # Hidden states of the unperturbed sentence are the optimization target.
    ori_output = model(**_input)['last_hidden_state']
    ori_embedding = model.wte(_input['input_ids']).detach()
    ori_embedding.requires_grad = False
    ori_word_one_hot = torch.nn.functional.one_hot(_input['input_ids'].detach(),
                                                   num_classes=subword_num).to(device)

    # Trainable noise over the vocabulary distribution of the masked positions.
    noise = torch.randn(ori_embedding.shape[0], ori_embedding.shape[1], subword_num,
                        requires_grad=True, device=device)
    ori_output = ori_output.detach()
    _input_ = deepcopy(_input)
    del _input_['input_ids']

    start_time = time()
    for _i in range(step):
        bar.progress((_i + 1) / step)
        perturbed_embedding = ori_embedding.clone()
        for i in range(len(noise_mask)):
            # Relax the one-hot vector with noise, renormalize it into a distribution,
            # and map it back to embedding space as a convex combination of word embeddings.
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_embedding[:, noise_mask[i]] = torch.matmul(_tmp_perturbed_input, model.wte.weight)
        _input_['inputs_embeds'] = perturbed_embedding
        outputs_perturbed = model(**_input_)['last_hidden_state']

        loss = loss_func(ori_output, outputs_perturbed)
        loss.backward()
        # Plain gradient descent on the noise tensor.
        noise.data = noise.data - lr * noise.grad.detach()
        noise.grad.zero_()
        _bar_text.text(f'{(time() - start_time) * (step - _i - 1) / (_i + 1):.2f} seconds left')

    # Validate: discretize the relaxed distributions back to concrete token ids.
    with torch.no_grad():
        perturbed_inputs = deepcopy(_input)
        for i in range(len(noise_mask)):
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_inputs['input_ids'][:, noise_mask[i]] = torch.argmax(_tmp_perturbed_input, dim=-1).long()
        perturbed_questions = []
        for i in range(restarts):
            # Keep the text before the end-of-text token; the original separator string
            # was lost, so tokenizer.eos_token is an assumption.
            perturbed_questions.append(
                tokenizer.decode(perturbed_inputs["input_ids"][i]).split(tokenizer.eos_token)[0])
    return perturbed_questions
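# A minimal sketch of the continuous relaxation that `run` performs, with toy shapes.
# Kept commented out (like the example snippets elsewhere in this file) so it does not
# execute on every Streamlit rerun:
#
#   one_hot = torch.nn.functional.one_hot(torch.tensor([5]), num_classes=50257).float()
#   noise = torch.randn(1, 50257, requires_grad=True)
#   relaxed = one_hot + noise
#   relaxed = relaxed / relaxed.sum(-1, keepdim=True)   # each row sums to 1
#   soft_embedding = relaxed @ model.wte.weight         # convex mix of word embeddings
#   hard_ids = torch.argmax(relaxed, dim=-1)            # discretize after optimization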
option = st.selectbox(
    'Which method would you like to use?',
    ('GPT-2 (Searching secret languages based on GPT-2)',
     'Use the secret language we found on ALBERT, DistilBERT, and RoBERTa.')
)

title = st.text_area('Input text.', 'Which name is also used to describe the Amazon rainforest in English?')
if option == 'GPT-2 (Searching secret languages based on GPT-2)':
    _cols = st.columns(2)
    restarts = _cols[0].number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
    step = _cols[1].number_input('Steps for searching secret language.', value=100, min_value=1, step=1, format='%d')
else:
    restarts = st.number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')


def get_secret_language(title):
    # Dictionary files are bucketed by the first character of the word:
    # digits, ASCII letters (one file per code point), and everything else.
    if ord(title[0]) in range(48, 58):  # '0'-'9' (the original range missed '9')
        file_name = 'num_dict.pkl'
    elif ord(title[0]) in list(range(97, 123)) + list(range(65, 91)):  # 'a'-'z', 'A'-'Z'
        file_name = f'{ord(title[0])}_dict.pkl'
    else:
        file_name = 'other_dict.pkl'
    datas = pickle.load(open(f'all_secret_langauge_by_fist/{file_name}', 'rb'))
    data_ = datas[title.strip()]
    _sls_id = []
    for i in range(len(data_['secret languages'])):
        # Recover the token id of each secret-language word from its replaced sentence.
        new_ids = tokenizer(data_['replaced sentences'][i])['input_ids']
        _sl = data_['secret languages'][i]
        for _id in new_ids:
            if _sl.strip() == tokenizer.decode(_id):
                _sls_id.append(_id)
                break
    return _sls_id
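# Commented example of what `get_secret_language` returns. The word is hypothetical; it
# must be a key present in the dictionaries under all_secret_langauge_by_fist/ (i.e. one
# of the entries loaded into all_keys above):
#
#   ids = get_secret_language('rainforest')
#   print([tokenizer.decode(i) for i in ids])  # secret-language tokens for that word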
if button('Tokenize', key='tokenizer'):
    # Clear stale selections from earlier runs before re-tokenizing.
    for key in list(st.session_state.keys()):
        if key not in ['tokenizer', 'start'] and 'tokenizer_' not in key:
            del st.session_state[key]
    input_ids = tokenizer(title)['input_ids']
    st.markdown('## Choose the (sub)words you want to replace.')
    subwords = [tokenizer.decode(i) for i in input_ids]
    _len = len(subwords)
    # Lay the subword buttons out six to a row.
    for i in range((_len + 5) // 6):
        cols = st.columns(6)
        for j in range(6):
            with cols[j]:
                _index = i * 6 + j
                if _index < _len:
                    disable = False
                    # The dictionary method only works for words with a known secret language.
                    if subwords[_index].strip() not in all_keys and \
                            option == 'Use the secret language we found on ALBERT, DistilBERT, and RoBERTa.':
                        disable = True
                    button(subwords[_index], key=f'tokenizer_{_index}', disabled=disable)
    st.markdown('## Ready to go? Hold on tight.')
    if button('Give it a shot!', key='start'):
        chose_indices = []
        for key in st.session_state:
            if st.session_state[key] and 'tokenizer_' in key:
                chose_indices.append(int(key.replace('tokenizer_', '')))
        if len(chose_indices):
            _bar_text = st.empty()
            if option == 'GPT-2 (Searching secret languages based on GPT-2)':
                bar = st.progress(0)
                outputs = run(model, _bar_text=_bar_text, bar=bar, text=title,
                              noise_mask=chose_indices, restarts=restarts, step=step)
            else:
                _new_ids = []
                _sl = {}
                for j in chose_indices:
                    _sl[j] = get_secret_language(tokenizer.decode(input_ids[j]).strip())
                # Build each replacement sentence by cycling through the secret-language ids
                # at every chosen position and copying the original ids elsewhere.
                for i in range(restarts):
                    _tmp = []
                    for j in range(len(input_ids)):
                        if j in chose_indices:
                            _tmp.append(_sl[j][i % len(_sl[j])])
                        else:
                            _tmp.append(input_ids[j])
                    _new_ids.append(_tmp)
                outputs = [tokenizer.decode(_new_ids[i]).split(tokenizer.eos_token)[0]
                           for i in range(restarts)]
            st.success(f'We found {restarts} replacements!', icon="✅")
            st.markdown('<br>'.join(outputs), unsafe_allow_html=True)
        else:
            st.error('Choose at least one subword.')
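# To try this page locally (assuming this file is saved as app.py):
#
#   streamlit run app.py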