anonymousauthors committed
Commit 91b1515 • 1 Parent(s): 5a5f4eb

Upload 4 files

keys.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:701fae656585df909b98fdcfa4ec681754a99dd9a16eb07559a948725ba25b73
+ size 187154
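Note: keys.pkl is tracked with Git LFS, so the committed blob is only this three-line pointer (spec version, SHA-256 oid of the real content, and its size in bytes, 187,154); the pickle itself is materialized by an LFS-aware clone or `git lfs pull`. A minimal sanity check, sketched under the assumption that keys.pkl sits in the working directory and unpickles to a list of strings (as 2_😈_Blackbox_Attack.py expects):

import pickle

# LFS pointer stubs are small text files starting with 'version https://git-lfs...';
# seeing that marker means the real payload was never fetched.
with open('keys.pkl', 'rb') as f:
    head = f.read(7)
if head == b'version':
    raise RuntimeError('keys.pkl is still an LFS pointer stub; run `git lfs pull` first.')

with open('keys.pkl', 'rb') as f:
    all_keys = pickle.load(f)  # expected: a list of strings, one per dictionary entry
print(len(all_keys))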
pages/1_📖_Dictionary_(Browse).py ADDED
@@ -0,0 +1,175 @@
+ import streamlit as st
+ import pandas as pd
+ # import gdown
+ import os
+ import pickle
+ from streamlit import session_state as _state
+
+ from PyDictionary import PyDictionary
+ from streamlit_extras.colored_header import colored_header
+
+ dictionary = PyDictionary()
+
+ st.set_page_config(layout="wide", page_title="ACl23 Secret Language")
+ from streamlit_extras.switch_page_button import switch_page
+ from streamlit_extras.stateful_button import button
+ # for key in st.session_state.keys():
+ #     del st.session_state[key]
+ # st.markdown(st.session_state)
+ # want_to_contribute = st.button("I want to contribute!")
+ # if want_to_contribute:
+ #     # switch_page("dictionary (search)")
+ #     switch_page("home")
+
+ # st.title("ACl23 Secret Language")
+
+ # sidebar
+ st.sidebar.header("📙 Dictionary (Browse)")
+ # title = st.sidebar.text_input(":red[Search secret languages given the following word]", 'Asian')
+ buttons = {}
+
+ # def call_back():
+ #     for k in buttons:
+ #         if buttons[k]:
+ #             st.text(k)
+
+ with st.sidebar:
+     cols0 = st.columns(8)
+     for i in range(len(cols0)):
+         with cols0[i]:
+             buttons[chr(65 + i)] = st.button(chr(65 + i))
+             # buttons[chr(65 + i)] = button(chr(65 + i), key=chr(65 + i))
+     cols1 = st.columns(8)
+     for i in range(len(cols1)):
+         with cols1[i]:
+             buttons[chr(65 + 8 + i)] = st.button(chr(65 + 8 + i))
+             # buttons[chr(65 + 8 + i)] = button(chr(65 + 8 + i), key=chr(65 + 8 + i))
+     cols2 = st.columns(8)
+     for i in range(len(cols2)):
+         with cols2[i]:
+             buttons[chr(65 + 16 + i)] = st.button(chr(65 + 16 + i))
+             # buttons[chr(65 + 16 + i)] = button(chr(65 + 16 + i), key=chr(65 + 16 + i))
+     cols3 = st.columns(8)
+     for i in range(2):
+         with cols3[i]:
+             buttons[chr(65 + 24 + i)] = st.button(chr(65 + 24 + i))
+             # buttons[chr(65 + 24 + i)] = button(chr(65 + 24 + i), key=chr(65 + 24 + i))
+     cols4 = st.columns(2)
+     buttons['0-9'] = cols4[0].button('0-9')
+     buttons['Others'] = cols4[1].button('Others')
+     # select = st.radio(
+     #     "Select initial to browse.",
+     #     [chr(i) for i in range(97, 123)] + ['0-9', 'Others'],
+     #     key="nokeyjustselect",
+     # )
+
+ # if select == '0-9':
+ #     st.title(select)
+ #     file_names = ['num_dict.pkl']
+ # elif select == 'Others':
+ #     st.title(select)
+ #     file_names = ['other_dict.pkl']
+ # elif ord(select[0]) in list(range(97, 123)) + list(range(65, 91)):
+ #     st.title(chr(ord(select)))
+ #     file_names = [f'{ord(select[0]) - 32}_dict.pkl', f'{ord(select[0])}_dict.pkl']
+ # all_data = {}
+ # all_key = []
+ # for file_name in file_names:
+ #     _data = pickle.load(open(f'all_secret_langauge_by_fist/{file_name}', 'rb'))
+ #     all_data.update(_data)
+ #     all_key.extend(sorted(list(_data.keys())))
+ #     # st.markdown(file_name, unsafe_allow_html=True)
+ #     # st.markdown(_data.keys(), unsafe_allow_html=True)
+
+ # all_key = sorted(list(set(all_key)))
+ # # st.markdown(','.join(all_key))
+ # for key in all_key:
+ #     # if len(key) and key[0] != '"':
+ #     # st.markdown(key, unsafe_allow_html=True)
+ #     # st.change_page("home")
+ #     # st.markdown(f'<a href="Dictionary_(Search)?word={key}" target="_self">{key}</a>', unsafe_allow_html=True)
+ #     # print(key)
+ #     # word_buttons.append(st.button(f'{key}', key=key))
+ #     _button = st.button(f'{key}', key=key)
+ #     if _button:
+ #         st.session_state.click_word = key
+ #         switch_page("dictionary (search)")
+
+
+ # st.markdown(st.session_state)
+ all_word_button = {}
+ # for key in all_word_button:
+ #     if all_word_button[key]:
+ #         st.session_state.click_word = key
+ #         switch_page("dictionary (search)")
+ for k in st.session_state.keys():
+     if st.session_state[k]:
+         if 'button_' in k:
+             st.session_state.click_word = k.split('button_')[-1]
+             # st.markdown(k)
+             switch_page("dictionary (search)")
+
+ # all_condition = False
+ word_buttons = None
+
+ for k in buttons:
+     if buttons[k]:
+         # for _k in buttons:
+         #     if _k != k:
+         #         st.session_state[_k] = False
+
+         word_buttons = []
+         if k == '0-9':
+             # st.title(k)
+             colored_header(
+                 label=k,
+                 description="",
+                 color_name="violet-70",
+             )
+             file_names = ['num_dict.pkl']
+         elif k == 'Others':
+             # st.title(k)
+             colored_header(
+                 label=k,
+                 description="",
+                 color_name="violet-70",
+             )
+             file_names = ['other_dict.pkl']
+         elif ord(k[0]) in list(range(97, 123)) + list(range(65, 91)):
+             # st.title(chr(ord(k)))
+             colored_header(
+                 label=chr(ord(k)),
+                 description="",
+                 color_name="violet-70",
+             )
+             file_names = [f'{ord(k[0]) + 32}_dict.pkl', f'{ord(k[0])}_dict.pkl']
+         all_data = {}
+         all_key = []
+         for file_name in file_names:
+             _data = pickle.load(open(f'all_secret_langauge_by_fist/{file_name}', 'rb'))
+             all_data.update(_data)
+             all_key.extend(sorted(list(_data.keys())))
+             # st.markdown(file_name, unsafe_allow_html=True)
+             # st.markdown(_data.keys(), unsafe_allow_html=True)
+
+         all_key = sorted(list(set(all_key)))
+         # st.markdown(','.join(all_key))
+         for key in all_key:
+             # if len(key) and key[0] != '"':
+             # st.markdown(key, unsafe_allow_html=True)
+             # st.change_page("home")
+             # st.markdown(f'<a href="Dictionary_(Search)?word={key}" target="_self">{key}</a>', unsafe_allow_html=True)
+             # print(key)
+             # word_buttons.append(st.button(f'{key}', key=key))
+             all_word_button[key] = st.button(f'{key}', key=f'button_{key}')
+             # all_word_button[key] = button(f'{key}', key=key)
+
+
+
+ # for _button in word_buttons:
+ #     if _button:
+ #         # st.session_state.click_word = key
+ #         # st.markdown('asdfd')
+ #         switch_page("home")
+ #     with st.expander(key):
+ #         st.markdown(':red[Secret Languages:' + ','.join(all_data[key]['secret languages']), unsafe_allow_html=True)
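Note: the browse page above resolves a letter button to pickle shards by ASCII code: an uppercase K loads both f'{ord(K) + 32}_dict.pkl' (lowercase) and f'{ord(K)}_dict.pkl' (uppercase), digits live in num_dict.pkl, and everything else in other_dict.pkl. A standalone sketch of that shard rule, using a hypothetical helper shard_names (not part of this commit) and the directory name spelled exactly as committed:

import pickle

def shard_names(word):
    # Hypothetical helper mirroring the app's bucketing rule: shards are
    # named after the ASCII code of the entry's first character.
    first = word[0]
    if first.isdigit():
        return ['num_dict.pkl']
    if first.isascii() and first.isalpha():
        # try both case variants, e.g. 'A' -> 97_dict.pkl and 65_dict.pkl
        return [f'{ord(first.lower())}_dict.pkl', f'{ord(first.upper())}_dict.pkl']
    return ['other_dict.pkl']

entry = None
for name in shard_names('Asian'):
    shard = pickle.load(open(f'all_secret_langauge_by_fist/{name}', 'rb'))
    if 'Asian' in shard:
        entry = shard['Asian']  # expected keys: 'secret languages', 'replaced sentences'
        break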
pages/2_😈_Blackbox_Attack.py ADDED
@@ -0,0 +1,162 @@
+ import streamlit as st
+ import os
+ from streamlit_extras.stateful_button import button
+
+ from transformers import GPT2Tokenizer, GPT2Model
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ model = GPT2Model.from_pretrained('gpt2')
+ import pickle
+ all_keys = pickle.load(open('keys.pkl', 'rb'))
+ all_keys = [i.strip() for i in all_keys]
+ import torch
+ from copy import deepcopy
+ from time import time
+ st.title('Blackbox Attack')
+ def run(model, _bar_text=None, bar=None, text='Which name is also used to describe the Amazon rainforest in English?', loss_funt=torch.nn.MSELoss(), lr=1, noise_mask=[1, 2], restarts=10, step=100, device=torch.device('cpu')):
+     subword_num = model.wte.weight.shape[0]
+
+     _input = tokenizer([text] * restarts, return_tensors="pt")
+     for k in _input.keys():
+         _input[k] = _input[k].to(device)
+
+     ori_output = model(**_input)['last_hidden_state']
+
+     ori_embedding = model.wte(_input['input_ids']).detach()
+     ori_embedding.requires_grad = False
+     ori_word_one_hot = torch.nn.functional.one_hot(_input['input_ids'].detach(), num_classes=subword_num).to(device)
+
+     noise = torch.randn(ori_embedding.shape[0], ori_embedding.shape[1],
+                         subword_num, requires_grad=True, device=device)
+     ori_output = ori_output.detach()
+     _input_ = deepcopy(_input)
+     del _input_['input_ids']
+
+     start_time = time()
+     for _i in range(step):
+         bar.progress((_i + 1) / step)
+         perturbed_embedding = ori_embedding.clone()
+         for i in range(len(noise_mask)):
+             _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
+             _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
+             perturbed_embedding[:, noise_mask[i]] = torch.matmul(_tmp_perturbed_input, model.wte.weight)
+
+         _input_['inputs_embeds'] = perturbed_embedding
+         outputs_perturbed = model(**_input_)['last_hidden_state']
+
+         loss = loss_funt(ori_output, outputs_perturbed)
+         loss.backward()
+         noise.data = (noise.data - lr * noise.grad.detach())
+         noise.grad.zero_()
+         _bar_text.text(f'{(time() - start_time) * (step - _i - 1) / (_i + 1):.2f} seconds left')
+     # validate
+     with torch.no_grad():
+         perturbed_inputs = deepcopy(_input)
+         for i in range(len(noise_mask)):
+             _tmp_perturbed_input = ori_word_one_hot[:, noise_mask[i]] + noise[:, i]
+             _tmp_perturbed_input /= _tmp_perturbed_input.sum(-1, keepdim=True)
+             # print(f'torch.argmax(_tmp_perturbed_input, dim=-1).long(){torch.argmax(_tmp_perturbed_input, dim=-1).long()}')
+             perturbed_inputs['input_ids'][:, noise_mask[i]] = torch.argmax(_tmp_perturbed_input, dim=-1).long()
+         perturbed_questions = []
+         for i in range(restarts):
+             perturbed_questions.append(tokenizer.decode(perturbed_inputs["input_ids"][i]).split("</s></s>")[0])
+     return perturbed_questions
+
+
+ from transformers import GPT2Tokenizer, GPT2Model
+ tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+ model = GPT2Model.from_pretrained('gpt2')
+ # encoded_input = tokenizer(text, return_tensors='pt')
+ # output = model(**encoded_input)
+
+ option = st.selectbox(
+     'Which method would you like to use?',
+     ('GPT-2 (Searching secret languages based on GPT-2)', 'Use the secret language we found on ALBERT, DistilBERT, and RoBERTa.')
+ )
+
+ title = st.text_area('Input text.', 'Which name is also used to describe the Amazon rainforest in English?')
+
+ if option == 'GPT-2 (Searching secret languages based on GPT-2)':
+     _cols = st.columns(2)
+     restarts = _cols[0].number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
+     step = _cols[1].number_input('Steps for searching secret language.', value=100, min_value=1, step=1, format='%d')
+ else:
+     restarts = st.number_input('Number of replacements.', value=10, min_value=1, step=1, format='%d')
+
+ def get_secret_language(title):
+     if ord(title[0]) in list(range(48, 58)):
+         file_name = 'num_dict.pkl'
+     elif ord(title[0]) in list(range(97, 123)) + list(range(65, 91)):
+         file_name = f'{ord(title[0])}_dict.pkl'
+     else:
+         file_name = 'other_dict.pkl'
+     datas = pickle.load(open(f'all_secret_langauge_by_fist/{file_name}', 'rb'))
+     data_ = datas[title.strip()]
+
+     _sls_id = []
+     for i in range(len(data_['secret languages'])):
+         new_ids = tokenizer(data_['replaced sentences'][i])['input_ids']
+         _sl = data_['secret languages'][i]
+         for _id in new_ids:
+             if _sl.strip() == tokenizer.decode(_id):
+                 _sls_id.append(_id)
+                 break
+     return _sls_id
+
+ if button('Tokenize', key='tokenizer'):
+     for key in st.session_state.keys():
+         if key not in ['tokenizer', 'start'] and 'tokenizer_' not in key:
+             del st.session_state[key]
+     input_ids = tokenizer(title)['input_ids']
+     st.markdown('## Choose the (sub)words you want to replace.')
+     subwords = [tokenizer.decode(i) for i in input_ids]
+     _len = len(subwords)
+     for i in range(int(_len / 6) + 1):
+         cols = st.columns(6)
+         for j in range(6):
+             with cols[j]:
+                 _index = i * 6 + j
+                 if _index < _len:
+                     disable = False
+                     if subwords[_index].strip() not in all_keys and option == 'Use the secret language we found on ALBERT, DistilBERT, and RoBERTa.':
+                         disable = True
+                     button(subwords[_index], key=f'tokenizer_{_index}', disabled=disable)
+
+
+     # st.markdown(dict(st.session_state))
+     st.markdown('## Ready to go? Hold on tight.')
+     if button('Give it a shot!', key='start'):
+         chose_indices = []
+         for key in st.session_state:
+             if st.session_state[key]:
+                 if 'tokenizer_' in key:
+                     # st.markdown(key)
+                     chose_indices.append(int(key.replace('tokenizer_', '')))
+         if len(chose_indices):
+             _bar_text = st.empty()
+             if option == 'GPT-2 (Searching secret languages based on GPT-2)':
+                 bar = st.progress(0)
+                 # st.markdown('start')
+                 outputs = run(model, _bar_text=_bar_text, bar=bar, text=title, noise_mask=chose_indices, restarts=restarts, step=step)
+             else:
+                 _new_ids = []
+                 _sl = {}
+                 for j in chose_indices:
+                     _sl[j] = get_secret_language(tokenizer.decode(input_ids[j]).strip())
+                 for i in range(restarts):
+                     _tmp = []
+                     for j in range(len(input_ids)):
+                         if j in chose_indices:
+                             _tmp.append(_sl[j][i % len(_sl[j])])
+                         else:
+                             _tmp.append(input_ids[j])
+                     _new_ids.append(_tmp)
+                 # st.markdown(_new_ids)
+                 outputs = [tokenizer.decode(_new_ids[i]).split('</s></s>')[0] for i in range(restarts)]
+
+             st.success(f'We found {restarts} replacements!', icon="✅")
+             st.markdown('<br>'.join(outputs), unsafe_allow_html=True)
+         else:
+             st.error('Please choose at least one subword.')
+
+
+
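Note: the run() function above searches for replacements by relaxing each selected token position into a distribution over GPT-2's vocabulary (its one-hot vector plus trainable noise, renormalized), building soft embeddings from that distribution, and gradient-descending on the MSE between the perturbed and original last hidden states; the final distribution is then discretized with argmax. A stripped-down, single-position sketch of the same idea without the Streamlit plumbing (all variable names here are local to the sketch, not part of the commit):

import torch
from transformers import GPT2Tokenizer, GPT2Model

tok = GPT2Tokenizer.from_pretrained('gpt2')
gpt2 = GPT2Model.from_pretrained('gpt2').eval()
for p in gpt2.parameters():
    p.requires_grad_(False)  # only the noise is optimized

enc = tok('Which name is also used to describe the Amazon rainforest?', return_tensors='pt')
pos = 1  # index of the (sub)word to replace
target = gpt2(**enc)['last_hidden_state'].detach()
vocab = gpt2.wte.weight.shape[0]
one_hot = torch.nn.functional.one_hot(enc['input_ids'], num_classes=vocab).float()

noise = torch.randn(1, vocab, requires_grad=True)
for _ in range(100):
    dist = one_hot[:, pos] + noise
    dist = dist / dist.sum(-1, keepdim=True)  # renormalized relaxed one-hot
    emb = gpt2.wte(enc['input_ids']).detach().clone()
    emb[:, pos] = dist @ gpt2.wte.weight  # soft embedding for the chosen position
    out = gpt2(inputs_embeds=emb, attention_mask=enc['attention_mask'])['last_hidden_state']
    loss = torch.nn.functional.mse_loss(out, target)
    loss.backward()
    noise.data -= 1.0 * noise.grad  # lr = 1, as in run()
    noise.grad.zero_()

new_id = int((one_hot[:, pos] + noise).argmax(-1))  # discretize back to a token id
ids = enc['input_ids'][0].tolist()
print(tok.decode(ids[:pos] + [new_id] + ids[pos + 1:]))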