Immortalise commited on
Commit
505d6d4
1 Parent(s): 1c79925
Files changed (1) hide show
  1. parse.py +25 -24
parse.py CHANGED
@@ -123,50 +123,51 @@ def retrieve(model_name, dataset_name, attack_name, prompt_type):
123
  directory_path = "./db"
124
  md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md")
125
  sections_dict = split_markdown_by_title(md_dir)
126
-
127
  for cur_dataset in sections_dict.keys():
128
  if cur_dataset == dataset_name:
129
  dataset_dict = sections_dict[cur_dataset]
130
  for cur_attack in dataset_dict.keys():
 
131
  if cur_attack == attack_name:
132
- pass
133
 
134
  if attack_name == "translation":
135
- results = dataset_dict[attack_name].split("\n")
136
-
137
- atk_acc = []
138
 
139
- for result in results:
140
- if "acc: " not in result:
141
  continue
142
-
 
 
143
  import re
144
 
145
  match_atk = re.search(r'acc: (\d+\.\d+)%', result)
146
-
147
  number_atk = float(match_atk.group(1))
148
- atk_acc.append(number_atk)
 
 
 
149
 
150
- sorted_atk_acc = sorted(atk_acc)[:6]
151
 
152
- elif title in ["bertattack", "checklist", "deepwordbug", "stresstest", "textfooler", "textbugger"]:
153
 
154
- results = sections_dict[dataset][title].split("Original prompt: ")
155
- num = 0
156
-
157
 
158
- for result in results:
159
- if "Attacked prompt: " not in result:
160
- continue
161
- num += 1
162
- import re
163
- match_origin = re.search(r'Original acc: (\d+\.\d+)%', result)
164
- match_atk = re.search(r'attacked acc: (\d+\.\d+)%', result)
 
165
  if match_origin and match_atk:
166
  number_origin = float(match_origin.group(1))
167
  number_atk = float(match_atk.group(1))
168
- summary[title][dataset].append((number_origin - number_atk)/number_origin)
169
- summary[title]["Avg"].append((number_origin - number_atk)/number_origin)
170
 
171
  # print(model_shot, dataset, title, len(summary[attack][dataset]), num)
172
 
 
123
  directory_path = "./db"
124
  md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md")
125
  sections_dict = split_markdown_by_title(md_dir)
126
+ results = {}
127
  for cur_dataset in sections_dict.keys():
128
  if cur_dataset == dataset_name:
129
  dataset_dict = sections_dict[cur_dataset]
130
  for cur_attack in dataset_dict.keys():
131
+
132
  if cur_attack == attack_name:
 
133
 
134
  if attack_name == "translation":
135
+ prompts_dict = dataset_dict[attack_name].split("\n")
 
 
136
 
137
+ for prompt_summary in prompts_dict:
138
+ if "acc: " not in prompt_summary:
139
  continue
140
+
141
+ prompt = prompt_summary.split("prompt: ")[1]
142
+
143
  import re
144
 
145
  match_atk = re.search(r'acc: (\d+\.\d+)%', result)
 
146
  number_atk = float(match_atk.group(1))
147
+ results[prompt] = number_atk
148
+
149
+ sorted_results = sorted(results.items(), key=lambda item: item[1])[:6]
150
+
151
 
152
+ return sorted_results
153
 
154
+ elif attack_name in ["bertattack", "checklist", "deepwordbug", "stresstest", "textfooler", "textbugger"]:
155
 
156
+ prompts_dict = dataset_dict[attack_name].split("\n")
157
+ num = 0
158
+
159
 
160
+ for prompt_summary in prompts_dict:
161
+ if "Attacked prompt: " not in prompt_summary:
162
+ continue
163
+
164
+ num += 1
165
+ import re
166
+ match_origin = re.search(r'Original acc: (\d+\.\d+)%', prompt_summary)
167
+ match_atk = re.search(r'attacked acc: (\d+\.\d+)%', prompt_summary)
168
  if match_origin and match_atk:
169
  number_origin = float(match_origin.group(1))
170
  number_atk = float(match_atk.group(1))
 
 
171
 
172
  # print(model_shot, dataset, title, len(summary[attack][dataset]), num)
173