Spaces:
Sleeping
Sleeping
import os | |
import json | |
import argparse | |
import torch | |
def _parse_id_list(tokens):
    """Convert raw command-line tokens into a flat list of ints.

    Commas and any run of whitespace are both treated as separators, so
    ``["1, 2, 3"]``, ``["1,2,3"]`` and ``["1,", "2,", "3"]`` all produce
    ``[1, 2, 3]``.

    Args:
        tokens: List of strings collected by argparse (``nargs='+'``), or None.

    Returns:
        A list of ints, or None when ``tokens`` is None.

    Raises:
        ValueError: If any separated piece is not an integer literal.
    """
    if tokens is None:
        return None
    # Joining first lets one quoted argument and several bare arguments share
    # the same path; str.split() with no separator drops empty pieces, so
    # doubled spaces or trailing commas cannot reach int("").
    return [int(piece) for piece in " ".join(tokens).replace(",", " ").split()]


def _json_values_to_tensor(raw):
    """Parse a JSON object string and pack its values into a long tensor.

    Args:
        raw: JSON text encoding an object, or None.

    Returns:
        ``torch.tensor`` of the object's values (in JSON order) cast to long,
        or None when ``raw`` is None.
    """
    if raw is None:
        return None
    return torch.tensor(list(json.loads(raw).values())).long()


def get_args(argv=None):
    """Parse command-line options and post-process them.

    After parsing, this function:
      * converts ``--trigger``, ``--prompt`` and ``--prompt_adv`` to lists of ints,
      * decodes ``--label-map`` from JSON,
      * converts the values of the ``--label2ids`` / ``--key2ids`` JSON objects
        to long tensors,
      * sets ``args.device`` to ``cuda:<--cuda>`` when CUDA is available, else CPU,
      * creates ``output/AutoPrompt_<task>_<dataset_name>`` and rewrites
        ``args.output`` to a path inside that directory.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, required=True, help='Train data path')
    parser.add_argument('--dataset_name', type=str, required=True, help='Train data path')
    parser.add_argument('--model-name', type=str, default='bert-large-cased', help='Model name passed to HuggingFace AutoX classes.')
    parser.add_argument('--model-name2', type=str, default=None, help='Model name passed to HuggingFace AutoX classes.')
    parser.add_argument('--template', type=str, help='Template string')
    parser.add_argument('--label-map', type=str, default=None, help='JSON object defining label map')
    parser.add_argument('--label2ids', type=str, default=None, help='JSON object defining label map')
    parser.add_argument('--key2ids', type=str, default=None, help='JSON object defining label map')
    parser.add_argument('--poison_rate', type=float, default=0.05)
    parser.add_argument('--num-cand', type=int, default=50)
    parser.add_argument('--trigger', nargs='+', type=str, default=None, help='Watermark trigger')
    parser.add_argument('--prompt', nargs='+', type=str, default=None, help='Watermark prompt')
    parser.add_argument('--prompt_adv', nargs='+', type=str, default=None, help='Adv prompt')
    parser.add_argument('--max_train_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--max_eval_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--max_predict_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--max_pvalue_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--k', type=int, default=20, help='Number of label tokens to print')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--max_seq_length', type=int, default=512, help='input_ids length')
    parser.add_argument('--bsz', type=int, default=32, help='Batch size')
    parser.add_argument('--eval-size', type=int, default=40, help='Eval size')
    parser.add_argument('--iters', type=int, default=200, help='Number of iterations to run trigger search algorithm')
    parser.add_argument('--accumulation-steps', type=int, default=32)
    parser.add_argument('--seed', type=int, default=12345)
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--cuda', type=int, default=3)
    args = parser.parse_args(argv)

    # Token-id lists may arrive as one quoted string or as several arguments.
    args.trigger = _parse_id_list(args.trigger)
    args.prompt = _parse_id_list(args.prompt)
    args.prompt_adv = _parse_id_list(args.prompt_adv)

    if args.label_map is not None:
        args.label_map = json.loads(args.label_map)

    args.label2ids = _json_values_to_tensor(args.label2ids)
    args.key2ids = _json_values_to_tensor(args.key2ids)
    print(f"-> label2ids:{args.label2ids} \n-> key2ids:{args.key2ids}")

    args.device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

    out_root = os.path.join("output", f"AutoPrompt_{args.task}_{args.dataset_name}")
    # exist_ok tolerates reruns without hiding real failures (e.g. permissions).
    os.makedirs(out_root, exist_ok=True)

    # NOTE(review): only an explicit --output has "/" replaced; a model name
    # containing "/" yields a nested path whose parent is not created here.
    filename = f"{args.model_name}" if args.output is None else args.output.replace("/", "_")
    args.output = os.path.join(out_root, filename)
    return args