Spaces:
Runtime error
Runtime error
fix bugs
Browse files- transfer.py +11 -2
transfer.py
CHANGED
@@ -3,10 +3,19 @@ import os
|
|
3 |
import glob
|
4 |
import numpy as np
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
def retrieve_transfer(source, target, attack, shot):
|
8 |
-
source = source
|
9 |
-
target = target
|
|
|
10 |
file_dir = "./results_transfer/"+source+"_"+target+"/"+attack+"_"+str(shot)+"_shot.json"
|
11 |
with open(file_dir, 'r', encoding='utf-8') as f:
|
12 |
data = json.load(f)
|
|
|
3 |
import glob
|
4 |
import numpy as np
|
5 |
|
6 |
+
def convert_model_name(model):
|
7 |
+
name = {
|
8 |
+
"T5": "google-flan-t5-large",
|
9 |
+
"UL2": "google-flan-ul2",
|
10 |
+
"Vicuna": "vicuna-13b",
|
11 |
+
"ChatGPT": "chatgpt",
|
12 |
+
}
|
13 |
+
return name[model]
|
14 |
|
15 |
def retrieve_transfer(source, target, attack, shot):
|
16 |
+
source = convert_model_name(source)
|
17 |
+
target = convert_model_name(target)
|
18 |
+
attack = attack.lower()
|
19 |
file_dir = "./results_transfer/"+source+"_"+target+"/"+attack+"_"+str(shot)+"_shot.json"
|
20 |
with open(file_dir, 'r', encoding='utf-8') as f:
|
21 |
data = json.load(f)
|