Immortalise commited on
Commit
646e829
1 Parent(s): 505d6d4
Files changed (3) hide show
  1. __pycache__/parse.cpython-38.pyc +0 -0
  2. app.py +7 -2
  3. parse.py +50 -42
__pycache__/parse.cpython-38.pyc ADDED
Binary file (4.43 kB). View file
 
app.py CHANGED
@@ -39,8 +39,13 @@ def main():
39
  st.write(f"Prompt Type: {prompt_type}")
40
 
41
  if st.button("Retrieve"):
42
- output = retrieve(model_name, dataset_name, attack_name, prompt_type)
43
- st.write(f"Output: {output}")
 
 
 
 
 
44
 
45
  if __name__ == "__main__":
46
  main()
 
39
  st.write(f"Prompt Type: {prompt_type}")
40
 
41
  if st.button("Retrieve"):
42
+ results = retrieve(model_name, dataset_name, attack_name, prompt_type)
43
+
44
+ for result in results:
45
+ st.write("Original prompt: {}".format(result["origin prompt"]))
46
+ st.write("Original acc: {}".format(result["origin acc"]))
47
+ st.write("Attack prompt: {}".format(result["attack prompt"]))
48
+ st.write("Attack acc: {}".format(result["attack acc"]))
49
 
50
  if __name__ == "__main__":
51
  main()
parse.py CHANGED
@@ -120,13 +120,32 @@ def retrieve(model_name, dataset_name, attack_name, prompt_type):
120
  else:
121
  prompt_type = "role"
122
 
123
- directory_path = "./db"
124
  md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md")
125
  sections_dict = split_markdown_by_title(md_dir)
126
  results = {}
127
  for cur_dataset in sections_dict.keys():
128
  if cur_dataset == dataset_name:
129
  dataset_dict = sections_dict[cur_dataset]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  for cur_attack in dataset_dict.keys():
131
 
132
  if cur_attack == attack_name:
@@ -142,63 +161,52 @@ def retrieve(model_name, dataset_name, attack_name, prompt_type):
142
 
143
  import re
144
 
145
- match_atk = re.search(r'acc: (\d+\.\d+)%', result)
146
  number_atk = float(match_atk.group(1))
147
  results[prompt] = number_atk
148
 
149
  sorted_results = sorted(results.items(), key=lambda item: item[1])[:6]
150
-
151
 
152
- return sorted_results
 
 
 
 
153
 
154
  elif attack_name in ["bertattack", "checklist", "deepwordbug", "stresstest", "textfooler", "textbugger"]:
155
 
156
- prompts_dict = dataset_dict[attack_name].split("\n")
157
  num = 0
158
 
159
-
160
  for prompt_summary in prompts_dict:
161
  if "Attacked prompt: " not in prompt_summary:
162
  continue
163
 
 
 
 
 
 
 
 
164
  num += 1
165
  import re
166
  match_origin = re.search(r'Original acc: (\d+\.\d+)%', prompt_summary)
167
  match_atk = re.search(r'attacked acc: (\d+\.\d+)%', prompt_summary)
168
- if match_origin and match_atk:
169
- number_origin = float(match_origin.group(1))
170
- number_atk = float(match_atk.group(1))
171
-
172
- # print(model_shot, dataset, title, len(summary[attack][dataset]), num)
173
-
174
- # for atk in summary.keys():
175
- # for dataset in summary[atk].keys():
176
- # # if atk == "translation":
177
- # print(atk, dataset, len(summary[atk][dataset]))
178
- # # print(summary[atk][dataset][:10])
179
-
180
- output_dict = {}
181
 
182
- sorted_atk_name = ["TextBugger", "DeepWordBug", "TextFooler", "BertAttack", "CheckList", "StressTest", "Semantic"]
183
- sorted_dataset_name = ["SST-2", "CoLA", "QQP", "MRPC", "MNLI", "QNLI", "RTE", "WNLI", "MMLU", "SQuAD V2", "IWSLT", "UN Multi", "Math"]
184
-
185
- for atk in sorted_atk_name:
186
- output_dict[atk] = {}
187
- for dataset in sorted_dataset_name:
188
- output_dict[atk][dataset] = ""
189
-
190
- for sorted_atk in sorted_atk_name:
191
- for attack, dataset_drop_rates in summary.items():
192
- # attack = convert_attack_name(attack)
193
- if convert_attack_name(attack) == sorted_atk:
194
- for sorted_dataset in sorted_dataset_name:
195
- for dataset, drop_rates in dataset_drop_rates.items():
196
- if convert_dataset_name(dataset) == sorted_dataset:
197
- if len(drop_rates) > 0:
198
- output_dict[sorted_atk][sorted_dataset] = "{:.2f}".format(sum(drop_rates)/len(drop_rates)) + "\scriptsize{$\pm$" + "{:.2f}".format(np.std(drop_rates)) + "}"
199
- else:
200
- output_dict[sorted_atk][sorted_dataset] = "-"
201
-
202
- total_drop_rate = summary[attack]["Avg"]
203
- output_dict[sorted_atk]["Avg"] = "{:.2f}".format(np.mean(total_drop_rate)) + "\scriptsize{$\pm$" + "{:.2f}".format(np.std(total_drop_rate)) + "}"
204
-
 
120
  else:
121
  prompt_type = "role"
122
 
123
+ directory_path = "./adv_prompts"
124
  md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md")
125
  sections_dict = split_markdown_by_title(md_dir)
126
  results = {}
127
  for cur_dataset in sections_dict.keys():
128
  if cur_dataset == dataset_name:
129
  dataset_dict = sections_dict[cur_dataset]
130
+ best_acc = 0
131
+ best_prompt = ""
132
+ for cur_attack in dataset_dict.keys():
133
+ if cur_attack == "10 prompts":
134
+ prompts_dict = dataset_dict[cur_attack].split("\n")
135
+ num = 0
136
+ for prompt_summary in prompts_dict:
137
+ if "Acc: " not in prompt_summary:
138
+ continue
139
+ else:
140
+ import re
141
+ num += 1
142
+ match = re.search(r'Acc: (\d+\.\d+)%', prompt_summary)
143
+ if match:
144
+ number = float(match.group(1))
145
+ if number > best_acc:
146
+ best_acc = number
147
+ best_prompt = prompt_summary.split("prompt: ")[1]
148
+
149
  for cur_attack in dataset_dict.keys():
150
 
151
  if cur_attack == attack_name:
 
161
 
162
  import re
163
 
164
+ match_atk = re.search(r'acc: (\d+\.\d+)%', prompt_summary)
165
  number_atk = float(match_atk.group(1))
166
  results[prompt] = number_atk
167
 
168
  sorted_results = sorted(results.items(), key=lambda item: item[1])[:6]
 
169
 
170
+ returned_results = []
171
+ for result in sorted_results:
172
+ returned_results.append({"origin prompt": best_prompt, "origin acc": best_acc, "attack prompt": result[0], "attack acc": result[1]})
173
+
174
+ return returned_results
175
 
176
  elif attack_name in ["bertattack", "checklist", "deepwordbug", "stresstest", "textfooler", "textbugger"]:
177
 
178
+ prompts_dict = dataset_dict[attack_name].split("Original prompt: ")
179
  num = 0
180
 
181
+ returned_results = []
182
  for prompt_summary in prompts_dict:
183
  if "Attacked prompt: " not in prompt_summary:
184
  continue
185
 
186
+ origin_prompt = prompt_summary.split("\n")[0]
187
+ attack_prompt = prompt_summary.split("Attacked prompt: ")[1].split("Original acc: ")[0]
188
+ attack_prompt = bytes(attack_prompt[2:-1], "utf-8").decode("unicode_escape").encode("latin1").decode("utf-8")
189
+
190
+ print(origin_prompt)
191
+ print(attack_prompt)
192
+
193
  num += 1
194
  import re
195
  match_origin = re.search(r'Original acc: (\d+\.\d+)%', prompt_summary)
196
  match_atk = re.search(r'attacked acc: (\d+\.\d+)%', prompt_summary)
197
+ if match_origin and match_atk:
198
+ if prompt_type == "task":
199
+ if num > 3:
200
+ break
201
+ else:
202
+ if num < 3:
203
+ continue
204
+ number_origin = float(match_origin.group(1))
205
+ number_atk = float(match_atk.group(1))
206
+ returned_results.append({"origin prompt": origin_prompt, "origin acc": number_origin, "attack prompt": attack_prompt, "attack acc": number_atk})
207
+
208
+ return returned_results
 
209
 
210
+
211
+ if __name__ == "__main__":
212
+ print(retrieve("T5", "CoLA", "BertAttack", "zeroshot_task"))