PromptBench / transfer.py
March07's picture
fix bugs
8bab3c8
raw
history blame
617 Bytes
import json
import os
import glob
import numpy as np
def convert_model_name(model):
    """Translate a short model label into its long-form identifier.

    Args:
        model: One of ``"T5"``, ``"UL2"``, ``"Vicuna"`` or ``"ChatGPT"``.
            The lookup is case-sensitive.

    Returns:
        The long-form name for the label, e.g. ``"google-flan-t5-large"``
        for ``"T5"``.

    Raises:
        KeyError: If ``model`` is not one of the supported labels. The
            message names the offending label and lists the valid ones.
    """
    names = {
        "T5": "google-flan-t5-large",
        "UL2": "google-flan-ul2",
        "Vicuna": "vicuna-13b",
        "ChatGPT": "chatgpt",
    }
    try:
        return names[model]
    except KeyError:
        # Keep the exception type callers may already catch, but say which
        # labels are accepted instead of echoing only the missing key.
        supported = ", ".join(sorted(names))
        raise KeyError(
            f"unknown model {model!r}; expected one of: {supported}"
        ) from None
def retrieve_transfer(source, target, attack, shot):
    """Load the stored transfer results for a source/target model pair.

    The file that is read lives at
    ``./results_transfer/<source>_<target>/<attack>_<shot>_shot.json``
    (relative to the current working directory), where ``<source>`` and
    ``<target>`` are the long-form names returned by
    ``convert_model_name`` and ``<attack>`` is the lower-cased attack name.

    Args:
        source: Short label of the source model, as accepted by
            ``convert_model_name``.
        target: Short label of the target model, as accepted by
            ``convert_model_name``.
        attack: Attack name; it is lower-cased before being used in the
            file name.
        shot: Value placed before ``_shot.json`` in the file name; it is
            converted with ``str``.

    Returns:
        The object decoded from the JSON file.

    Raises:
        KeyError: If ``source`` or ``target`` is not a supported label.
        OSError: If the results file cannot be opened (for example, it
            does not exist).
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    source_name = convert_model_name(source)
    target_name = convert_model_name(target)
    attack_name = attack.lower()
    # Build the path from components rather than gluing separators by hand.
    file_path = os.path.join(
        ".",
        "results_transfer",
        f"{source_name}_{target_name}",
        f"{attack_name}_{shot!s}_shot.json",
    )
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)