Spaces:
Sleeping
Sleeping
File size: 4,201 Bytes
7713b1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import os
import json
import argparse
import torch
def _parse_token_ids(values):
    """Convert the strings of an ``nargs='+'`` option into a list of ints.

    Each element may contain one id or several ids separated by commas and/or
    whitespace, so ``["1,2", "3"]``, ``["1, 2, 3"]`` and ``["1", "2", "3"]``
    all yield ``[1, 2, 3]``. Returns ``None`` when ``values`` is ``None``.

    Raises:
        ValueError: if a token is not an integer literal.
    """
    if values is None:
        return None
    tokens = []
    for value in values:
        # Commas are separators, not characters to delete: deleting them would
        # turn "1,2" into the single id 12. str.split() also drops the empty
        # strings that repeated spaces would otherwise produce.
        tokens.extend(value.replace(",", " ").split())
    return [int(tok) for tok in tokens]


def get_args(argv=None):
    """Build the CLI parser, parse the arguments and post-process the result.

    Post-processing applied to the parsed namespace:

    * ``trigger``, ``prompt`` and ``prompt_adv`` become lists of ints (ids may
      be separated by whitespace and/or commas).
    * ``label_map`` is decoded from JSON.
    * ``label2ids`` / ``key2ids`` are decoded from JSON objects; their values
      (keys are discarded, order follows the JSON object) are turned into
      ``torch.long`` tensors.
    * ``device`` is added: ``cuda:<--cuda>`` when CUDA is available, else CPU.
    * ``output`` is rewritten to
      ``output/AutoPrompt_<task>_<dataset_name>/<name>``, where ``<name>`` is
      ``--output`` (or ``--model-name`` when ``--output`` is unset) with every
      ``/`` replaced by ``_``. The ``output/AutoPrompt_...`` directory is
      created if missing.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly like ``ArgumentParser.parse_args()``.

    Returns:
        argparse.Namespace with the processed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, required=True, help='Task name (used in the output directory name)')
    parser.add_argument('--dataset_name', type=str, required=True, help='Dataset name (used in the output directory name)')
    parser.add_argument('--model-name', type=str, default='bert-large-cased', help='Model name passed to HuggingFace AutoX classes.')
    parser.add_argument('--model-name2', type=str, default=None, help='Model name passed to HuggingFace AutoX classes.')
    parser.add_argument('--template', type=str, help='Template string')
    parser.add_argument('--label-map', type=str, default=None, help='JSON object defining label map')
    parser.add_argument('--label2ids', type=str, default=None, help='JSON object whose values become the label2ids tensor')
    parser.add_argument('--key2ids', type=str, default=None, help='JSON object whose values become the key2ids tensor')
    parser.add_argument('--poison_rate', type=float, default=0.05)
    parser.add_argument('--num-cand', type=int, default=50)
    parser.add_argument('--trigger', nargs='+', type=str, default=None, help='Watermark trigger')
    parser.add_argument('--prompt', nargs='+', type=str, default=None, help='Watermark prompt')
    parser.add_argument('--prompt_adv', nargs='+', type=str, default=None, help='Adv prompt')
    parser.add_argument('--max_train_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--max_eval_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--max_predict_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--max_pvalue_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--k', type=int, default=20, help='Number of label tokens to print')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--max_seq_length', type=int, default=512, help='input_ids length')
    parser.add_argument('--bsz', type=int, default=32, help='Batch size')
    parser.add_argument('--eval-size', type=int, default=40, help='Eval size')
    parser.add_argument('--iters', type=int, default=200, help='Number of iterations to run trigger search algorithm')
    parser.add_argument('--accumulation-steps', type=int, default=32)
    parser.add_argument('--seed', type=int, default=12345)
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--cuda', type=int, default=3)
    args = parser.parse_args(argv)

    # Token-id options arrive as strings; convert them to int lists.
    args.trigger = _parse_token_ids(args.trigger)
    args.prompt = _parse_token_ids(args.prompt)
    args.prompt_adv = _parse_token_ids(args.prompt_adv)

    if args.label_map is not None:
        args.label_map = json.loads(args.label_map)
    if args.label2ids is not None:
        # Only the values are kept; json.loads preserves object order.
        args.label2ids = torch.tensor(list(json.loads(args.label2ids).values())).long()
    if args.key2ids is not None:
        args.key2ids = torch.tensor(list(json.loads(args.key2ids).values())).long()
    print(f"-> label2ids:{args.label2ids} \n-> key2ids:{args.key2ids}")

    args.device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

    out_root = os.path.join("output", f"AutoPrompt_{args.task}_{args.dataset_name}")
    # exist_ok replaces the old bare `except: pass`, which also hid real
    # failures such as permission errors.
    os.makedirs(out_root, exist_ok=True)
    # Sanitise both sources the same way: HF model names such as "org/model"
    # would otherwise point into a sub-directory that was never created.
    filename = (args.model_name if args.output is None else args.output).replace("/", "_")
    args.output = os.path.join(out_root, filename)
    return args
|