Immortalise commited on
Commit
1c79925
1 Parent(s): 3429aba
adv_prompts/chatgpt_fewshot.md ADDED
The diff for this file is too large to render. See raw diff
 
adv_prompts/chatgpt_zeroshot.md ADDED
The diff for this file is too large to render. See raw diff
 
adv_prompts/t5_fewshot.md ADDED
The diff for this file is too large to render. See raw diff
 
adv_prompts/t5_zeroshot.md ADDED
The diff for this file is too large to render. See raw diff
 
adv_prompts/ul2_fewshot.md ADDED
The diff for this file is too large to render. See raw diff
 
adv_prompts/ul2_zeroshot.md ADDED
The diff for this file is too large to render. See raw diff
 
adv_prompts/vicuna_fewshot.md ADDED
The diff for this file is too large to render. See raw diff
 
adv_prompts/vicuna_zeroshot.md ADDED
The diff for this file is too large to render. See raw diff
 
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from parse import retrieve
3
+
4
+
5
+ def main():
6
+ st.title("Streamlit App")
7
+
8
+ model_name = st.selectbox(
9
+ "Select Model",
10
+ options=["T5", "Vicuna", "UL2", "ChatGPT"],
11
+ index=0,
12
+ )
13
+
14
+ dataset_name = st.selectbox(
15
+ "Select Dataset",
16
+ options=[
17
+ "SST-2", "CoLA", "QQP", "MRPC", "MNLI", "QNLI",
18
+ "RTE", "WNLI", "MMLU", "SQuAD V2", "IWSLT 2017", "UN Multi", "Math"
19
+ ],
20
+ index=0,
21
+ )
22
+
23
+ attack_name = st.selectbox(
24
+ "Select Attack",
25
+ options=[
26
+ "BertAttack", "CheckList", "DeepWordBug", "StressTest", "TextFooler", "TextBugger", "Semantic"
27
+ ],
28
+ index=0,
29
+ )
30
+
31
+ prompt_type = st.selectbox(
32
+ "Select Prompt Type",
33
+ options=["zeroshot-task", "zeroshot-role", "fewshot-task", "fewshot-role"],
34
+ index=0,
35
+ )
36
+
37
+ st.write(f"Model: {model_name}")
38
+ st.write(f"Dataset: {dataset_name}")
39
+ st.write(f"Prompt Type: {prompt_type}")
40
+
41
+ if st.button("Retrieve"):
42
+ output = retrieve(model_name, dataset_name, attack_name, prompt_type)
43
+ st.write(f"Output: {output}")
44
+
45
+ if __name__ == "__main__":
46
+ main()
parse.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import re
3
+
4
+
5
+ def split_markdown_by_title(markdown_file):
6
+ with open(markdown_file, 'r', encoding='utf-8') as f:
7
+ content = f.read()
8
+
9
+ re_str = "# cola|# mnli|# mrpc|# qnli|# qqp|# rte|# sst2|# wnli|# mmlu|# squad_v2|# iwslt|# un_multi|# math"
10
+
11
+ datasets = ["# cola", "# mnli", "# mrpc", "# qnli", "# qqp", "# rte", "# sst2", "# wnli",
12
+ "# mmlu", "# squad_v2", "# iwslt", "# un_multi", "# math"]
13
+
14
+ # re_str = "# cola|# mnli|# mrpc|# qnli|# qqp|# rte|# sst2|# wnli"
15
+ # datasets = ["# cola", "# mnli", "# mrpc", "# qnli", "# qqp", "# rte", "# sst2", "# wnli"]
16
+ primary_sections = re.split(re_str, content)[1:]
17
+ assert len(primary_sections) == len(datasets)
18
+
19
+ all_sections_dict = {}
20
+
21
+ for dataset, primary_section in zip(datasets, primary_sections):
22
+ re_str = "## "
23
+ results = re.split(re_str, primary_section)
24
+ keywords = ["10 prompts", "bertattack", "checklist", "deepwordbug", "stresstest",
25
+ "textfooler", "textbugger", "translation"]
26
+
27
+ secondary_sections_dict = {}
28
+ for res in results:
29
+ for keyword in keywords:
30
+ if keyword in res.lower():
31
+ secondary_sections_dict[keyword] = res
32
+ break
33
+
34
+ all_sections_dict[dataset] = secondary_sections_dict
35
+
36
+ return all_sections_dict
37
+ # def prompts_understanding(sections_dict):
38
+ # for dataset in sections_dict.keys():
39
+ # # print(dataset)
40
+ # for title in sections_dict[dataset].keys():
41
+ # if title == "10 prompts":
42
+ # prompts = sections_dict[dataset][title].split("\n")
43
+ # num = 0
44
+ # task_prompts_acc = []
45
+ # role_prompts_acc = []
46
+ # for prompt in prompts:
47
+ # if "Acc: " not in prompt:
48
+ # continue
49
+ # else:
50
+ # import re
51
+ # num += 1
52
+ # match = re.search(r'Acc: (\d+\.\d+)%', prompt)
53
+ # if match:
54
+ # number = float(match.group(1))
55
+ # if num <= 10:
56
+ # task_prompts_acc.append(number)
57
+ # else:
58
+ # role_prompts_acc.append(number)
59
+
60
+ # print(task_prompts_acc)
61
+ # print(role_prompts_acc)
62
+ import os
63
+ def list_files(directory):
64
+ files = [os.path.join(directory, d) for d in os.listdir(directory) if not os.path.isdir(os.path.join(directory, d))]
65
+ return files
66
+
67
+ def convert_model_name(attack):
68
+ attack_name = {
69
+ "T5": "t5",
70
+ "UL2": "ul2",
71
+ "Vicuna": "vicuna",
72
+ "ChatGPT": "chatgpt",
73
+ }
74
+ return attack_name[attack]
75
+
76
+ def convert_attack_name(attack):
77
+ attack_name = {
78
+ "BertAttack": "bertattack",
79
+ "CheckList": "checklist",
80
+ "DeepWordBug": "deepwordbug",
81
+ "StressTest": "stresstest",
82
+ "TextFooler": "textfooler",
83
+ "TextBugger": "textbugger",
84
+ "Semantic": "translation",
85
+ }
86
+ return attack_name[attack]
87
+
88
+ def convert_dataset_name(dataset):
89
+ dataset_name = {
90
+ "CoLA": "# cola",
91
+ "MNLI": "# mnli",
92
+ "MRPC": "# mrpc",
93
+ "QNLI": "# qnli",
94
+ "QQP": "# qqp",
95
+ "RTE": "# rte",
96
+ "SST-2": "# sst2",
97
+ "WNLI": "# wnli",
98
+ "MMLU": "# mmlu",
99
+ "SQuAD V2": "# squad_v2",
100
+ "IWSLT": "# iwslt",
101
+ "UN Multi": "# un_multi",
102
+ "Math": "# math",
103
+ "Avg": "Avg",
104
+ }
105
+ return dataset_name[dataset]
106
+
107
+
108
+ def retrieve(model_name, dataset_name, attack_name, prompt_type):
109
+ model_name = convert_model_name(model_name)
110
+ dataset_name = convert_dataset_name(dataset_name)
111
+ attack_name = convert_attack_name(attack_name)
112
+
113
+ if "zero" in prompt_type:
114
+ shot = "zeroshot"
115
+ else:
116
+ shot = "fewshot"
117
+
118
+ if "task" in prompt_type:
119
+ prompt_type = "task"
120
+ else:
121
+ prompt_type = "role"
122
+
123
+ directory_path = "./db"
124
+ md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md")
125
+ sections_dict = split_markdown_by_title(md_dir)
126
+
127
+ for cur_dataset in sections_dict.keys():
128
+ if cur_dataset == dataset_name:
129
+ dataset_dict = sections_dict[cur_dataset]
130
+ for cur_attack in dataset_dict.keys():
131
+ if cur_attack == attack_name:
132
+ pass
133
+
134
+ if attack_name == "translation":
135
+ results = dataset_dict[attack_name].split("\n")
136
+
137
+ atk_acc = []
138
+
139
+ for result in results:
140
+ if "acc: " not in result:
141
+ continue
142
+
143
+ import re
144
+
145
+ match_atk = re.search(r'acc: (\d+\.\d+)%', result)
146
+
147
+ number_atk = float(match_atk.group(1))
148
+ atk_acc.append(number_atk)
149
+
150
+ sorted_atk_acc = sorted(atk_acc)[:6]
151
+
152
+ elif title in ["bertattack", "checklist", "deepwordbug", "stresstest", "textfooler", "textbugger"]:
153
+
154
+ results = sections_dict[dataset][title].split("Original prompt: ")
155
+ num = 0
156
+
157
+
158
+ for result in results:
159
+ if "Attacked prompt: " not in result:
160
+ continue
161
+ num += 1
162
+ import re
163
+ match_origin = re.search(r'Original acc: (\d+\.\d+)%', result)
164
+ match_atk = re.search(r'attacked acc: (\d+\.\d+)%', result)
165
+ if match_origin and match_atk:
166
+ number_origin = float(match_origin.group(1))
167
+ number_atk = float(match_atk.group(1))
168
+ summary[title][dataset].append((number_origin - number_atk)/number_origin)
169
+ summary[title]["Avg"].append((number_origin - number_atk)/number_origin)
170
+
171
+ # print(model_shot, dataset, title, len(summary[attack][dataset]), num)
172
+
173
+ # for atk in summary.keys():
174
+ # for dataset in summary[atk].keys():
175
+ # # if atk == "translation":
176
+ # print(atk, dataset, len(summary[atk][dataset]))
177
+ # # print(summary[atk][dataset][:10])
178
+
179
+ output_dict = {}
180
+
181
+ sorted_atk_name = ["TextBugger", "DeepWordBug", "TextFooler", "BertAttack", "CheckList", "StressTest", "Semantic"]
182
+ sorted_dataset_name = ["SST-2", "CoLA", "QQP", "MRPC", "MNLI", "QNLI", "RTE", "WNLI", "MMLU", "SQuAD V2", "IWSLT", "UN Multi", "Math"]
183
+
184
+ for atk in sorted_atk_name:
185
+ output_dict[atk] = {}
186
+ for dataset in sorted_dataset_name:
187
+ output_dict[atk][dataset] = ""
188
+
189
+ for sorted_atk in sorted_atk_name:
190
+ for attack, dataset_drop_rates in summary.items():
191
+ # attack = convert_attack_name(attack)
192
+ if convert_attack_name(attack) == sorted_atk:
193
+ for sorted_dataset in sorted_dataset_name:
194
+ for dataset, drop_rates in dataset_drop_rates.items():
195
+ if convert_dataset_name(dataset) == sorted_dataset:
196
+ if len(drop_rates) > 0:
197
+ output_dict[sorted_atk][sorted_dataset] = "{:.2f}".format(sum(drop_rates)/len(drop_rates)) + "\scriptsize{$\pm$" + "{:.2f}".format(np.std(drop_rates)) + "}"
198
+ else:
199
+ output_dict[sorted_atk][sorted_dataset] = "-"
200
+
201
+ total_drop_rate = summary[attack]["Avg"]
202
+ output_dict[sorted_atk]["Avg"] = "{:.2f}".format(np.mean(total_drop_rate)) + "\scriptsize{$\pm$" + "{:.2f}".format(np.std(total_drop_rate)) + "}"
203
+