import pickle
from copy import deepcopy
from time import time

import streamlit as st
import torch
from streamlit_extras.stateful_button import button
from transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')

# Subwords covered by the precomputed secret-language dictionaries.
with open('keys.pkl', 'rb') as f:
    all_keys = [key.strip() for key in pickle.load(f)]
st.title('Blackbox Attack')
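# Gradient-guided search for "secret language" replacements: the one-hot
# token distribution at each selected position is perturbed with learnable
# noise, optimised so that the perturbed hidden states stay close to the
# original ones.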
def run(model, _bar_text=None, bar=None, text='Which name is also used to describe the Amazon rainforest in English?', loss_fn=torch.nn.MSELoss(), lr=1, noise_mask=(1, 2), restarts=10, step=100, device=torch.device('cpu')):
    subword_num = model.wte.weight.shape[0]  # vocabulary size

    _input = tokenizer([text] * restarts, return_tensors="pt")
    for k in _input.keys():
        _input[k] = _input[k].to(device)

    ori_output = model(**_input)['last_hidden_state']

    ori_embedding = model.wte(_input['input_ids']).detach()
    ori_embedding.requires_grad = False
    ori_word_one_hot = torch.nn.functional.one_hot(_input['input_ids'].detach(), num_classes=subword_num).to(device)

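    # Learnable additive noise over the vocabulary; slice i along dim 1 is
    # applied to the i-th masked position.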
    noise = torch.randn(ori_embedding.shape[0], ori_embedding.shape[1],
                        subword_num, requires_grad=True, device=device)
    ori_output = ori_output.detach()
    _input_ = deepcopy(_input)
    del _input_['input_ids']

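    # Plain gradient descent on the noise; `bar` and `_bar_text` report
    # progress and a rough time-remaining estimate in the UI.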
    start_time = time()
    for _i in range(step):
        bar.progress((_i + 1) / step)
        perturbed_embedding = ori_embedding.clone()
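        # Mix one-hot with noise, renormalise to a distribution over the
        # vocabulary, then embed via a soft (differentiable) lookup.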
        for i in range(len(noise_mask)):
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_embedding[:, noise_mask[i]] = torch.matmul(_tmp_perturbed_input, model.wte.weight)

        _input_['inputs_embeds'] = perturbed_embedding
        outputs_perturbed = model(**_input_)['last_hidden_state']

        loss = loss_fn(ori_output, outputs_perturbed)
        loss.backward()
        noise.data = (noise.data - lr * noise.grad.detach())
        noise.grad.zero_()
        _bar_text.text(f'{(time() - start_time) * (step - _i - 1) / (_i + 1):.2f} seconds left')
    # Validation: project each relaxed distribution back to a hard token id
    # (argmax) and decode the perturbed sentences.
    with torch.no_grad():
        perturbed_inputs = deepcopy(_input)
        for i in range(len(noise_mask)):
            _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
            _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
            perturbed_inputs['input_ids'][:, noise_mask[i]] = torch.argmax(_tmp_perturbed_input, dim=-1).long()
        perturbed_questions = []
        for i in range(restarts):
            perturbed_questions.append(tokenizer.decode(perturbed_inputs["input_ids"][i]).split("</s></s>")[0])
    return perturbed_questions


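# Two modes: search for replacements on the fly with GPT-2, or look them up
# in dictionaries precomputed on ALBERT, DistilBERT, and RoBERTa.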
option = st.selectbox(
    'Which method would you like to use?',
    ('GPT-2 (Searching secret languages based on GPT-2)', 'Use the secret language we found on ALBERT, DistilBERT, and RoBERTa.')
)

title = st.text_area('Input text.', 'Which name is also used to describe the Amazon rainforest in English?')

if option == 'GPT-2 (Searching secret languages based on GPT-2)':
    _cols = st.columns(2)
    restarts = _cols[0].number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
    step = _cols[1].number_input('Steps for searching the secret language.', value=100, min_value=1, step=1, format='%d')
else:
    restarts = st.number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')

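# Look up the precomputed secret-language token ids for a subword; the
# dictionaries are sharded by the word's first character.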
def get_secret_language(title):
    # Shard files are keyed by the first character of the word.
    if ord(title[0]) in range(48, 58):  # digits '0'-'9'
        file_name = 'num_dict.pkl'
    elif ord(title[0]) in list(range(97, 123)) + list(range(65, 91)):  # ASCII letters
        file_name = f'{ord(title[0])}_dict.pkl'
    else:
        file_name = 'other_dict.pkl'
    with open(f'all_secret_langauge_by_fist/{file_name}', 'rb') as f:
        datas = pickle.load(f)
    data_ = datas[title.strip()]

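    # Recover each secret-language token id by re-tokenising the replaced
    # sentence and matching the decoded subword.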
    _sls_id = []
    for i in range(len(data_['secret languages'])):
        new_ids = tokenizer(data_['replaced sentences'][i])['input_ids']
        _sl = data_['secret languages'][i]
        for _id in new_ids:
            if _sl.strip() == tokenizer.decode(_id):
                _sls_id.append(_id)
                break
    return _sls_id

if button('Tokenize', key='tokenizer'):
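    # Reset state left over from a previous run, keeping the stateful
    # buttons' own keys.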
    for key in list(st.session_state.keys()):
        if key not in ['tokenizer', 'start'] and 'tokenizer_' not in key:
            del st.session_state[key]
    input_ids = tokenizer(title)['input_ids']
    st.markdown('## Choose the (sub)words you want to replace.')
    subwords = [tokenizer.decode(i) for i in input_ids]
    _len = len(subwords)
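    # Lay the subword buttons out in rows of six; in dictionary mode, disable
    # subwords without a precomputed secret language.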
    for i in range((_len + 5) // 6):
        cols = st.columns(6)
        for j in range(6):
            with cols[j]:
                _index = i * 6 + j
                if _index < _len:
                    disable = subwords[_index].strip() not in all_keys and option == 'Use the secret language we found on ALBERT, DistilBERT, and RoBERTa.'
                    button(subwords[_index], key=f'tokenizer_{_index}', disabled=disable)

    st.markdown('## Ready to go? Hold on tight.')
    if button('Give it a shot!', key='start'):
        # Collect the indices of the subword buttons that are toggled on.
        chose_indices = []
        for key in st.session_state:
            if st.session_state[key] and 'tokenizer_' in key:
                chose_indices.append(int(key.replace('tokenizer_', '')))
        if chose_indices:
            _bar_text = st.empty()
            if option == 'GPT-2 (Searching secret languages based on GPT-2)':
                bar = st.progress(0)
                outputs = run(model, _bar_text=_bar_text, bar=bar, text=title, noise_mask=chose_indices, restarts=restarts, step=step)
            else:
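                # Dictionary mode: replace each chosen position with its
                # precomputed ids, cycling when there are fewer ids than
                # restarts.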
                _new_ids = []
                _sl = {}
                for j in chose_indices:
                    _sl[j] = get_secret_language(tokenizer.decode(input_ids[j]).strip())
                for i in range(restarts):
                    _tmp = []
                    for j in range(len(input_ids)):
                        if j in chose_indices:
                            _tmp.append(_sl[j][i % len(_sl[j])])
                        else:
                            _tmp.append(input_ids[j])
                    _new_ids.append(_tmp)
                outputs = [tokenizer.decode(_new_ids[i]).split('</s></s>')[0] for i in range(restarts)]

            st.success(f'We found {restarts} replacements!', icon="✅")
            st.markdown('<br>'.join(outputs), unsafe_allow_html=True)
        else:
            st.error('Choose at least one subword.')