Spaces:
Runtime error
Runtime error
Immortalise
commited on
Commit
•
505d6d4
1
Parent(s):
1c79925
init
Browse files
parse.py
CHANGED
@@ -123,50 +123,51 @@ def retrieve(model_name, dataset_name, attack_name, prompt_type):
|
|
123 |
directory_path = "./db"
|
124 |
md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md")
|
125 |
sections_dict = split_markdown_by_title(md_dir)
|
126 |
-
|
127 |
for cur_dataset in sections_dict.keys():
|
128 |
if cur_dataset == dataset_name:
|
129 |
dataset_dict = sections_dict[cur_dataset]
|
130 |
for cur_attack in dataset_dict.keys():
|
|
|
131 |
if cur_attack == attack_name:
|
132 |
-
pass
|
133 |
|
134 |
if attack_name == "translation":
|
135 |
-
|
136 |
-
|
137 |
-
atk_acc = []
|
138 |
|
139 |
-
for
|
140 |
-
if "acc: " not in
|
141 |
continue
|
142 |
-
|
|
|
|
|
143 |
import re
|
144 |
|
145 |
match_atk = re.search(r'acc: (\d+\.\d+)%', result)
|
146 |
-
|
147 |
number_atk = float(match_atk.group(1))
|
148 |
-
|
|
|
|
|
|
|
149 |
|
150 |
-
|
151 |
|
152 |
-
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
|
|
165 |
if match_origin and match_atk:
|
166 |
number_origin = float(match_origin.group(1))
|
167 |
number_atk = float(match_atk.group(1))
|
168 |
-
summary[title][dataset].append((number_origin - number_atk)/number_origin)
|
169 |
-
summary[title]["Avg"].append((number_origin - number_atk)/number_origin)
|
170 |
|
171 |
# print(model_shot, dataset, title, len(summary[attack][dataset]), num)
|
172 |
|
|
|
123 |
directory_path = "./db"
|
124 |
md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md")
|
125 |
sections_dict = split_markdown_by_title(md_dir)
|
126 |
+
results = {}
|
127 |
for cur_dataset in sections_dict.keys():
|
128 |
if cur_dataset == dataset_name:
|
129 |
dataset_dict = sections_dict[cur_dataset]
|
130 |
for cur_attack in dataset_dict.keys():
|
131 |
+
|
132 |
if cur_attack == attack_name:
|
|
|
133 |
|
134 |
if attack_name == "translation":
|
135 |
+
prompts_dict = dataset_dict[attack_name].split("\n")
|
|
|
|
|
136 |
|
137 |
+
for prompt_summary in prompts_dict:
|
138 |
+
if "acc: " not in prompt_summary:
|
139 |
continue
|
140 |
+
|
141 |
+
prompt = prompt_summary.split("prompt: ")[1]
|
142 |
+
|
143 |
import re
|
144 |
|
145 |
match_atk = re.search(r'acc: (\d+\.\d+)%', result)
|
|
|
146 |
number_atk = float(match_atk.group(1))
|
147 |
+
results[prompt] = number_atk
|
148 |
+
|
149 |
+
sorted_results = sorted(results.items(), key=lambda item: item[1])[:6]
|
150 |
+
|
151 |
|
152 |
+
return sorted_results
|
153 |
|
154 |
+
elif attack_name in ["bertattack", "checklist", "deepwordbug", "stresstest", "textfooler", "textbugger"]:
|
155 |
|
156 |
+
prompts_dict = dataset_dict[attack_name].split("\n")
|
157 |
+
num = 0
|
158 |
+
|
159 |
|
160 |
+
for prompt_summary in prompts_dict:
|
161 |
+
if "Attacked prompt: " not in prompt_summary:
|
162 |
+
continue
|
163 |
+
|
164 |
+
num += 1
|
165 |
+
import re
|
166 |
+
match_origin = re.search(r'Original acc: (\d+\.\d+)%', prompt_summary)
|
167 |
+
match_atk = re.search(r'attacked acc: (\d+\.\d+)%', prompt_summary)
|
168 |
if match_origin and match_atk:
|
169 |
number_origin = float(match_origin.group(1))
|
170 |
number_atk = float(match_atk.group(1))
|
|
|
|
|
171 |
|
172 |
# print(model_shot, dataset, title, len(summary[attack][dataset]), num)
|
173 |
|