# File size: 4,201 Bytes
# Revision: 7713b1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import json
import argparse
import torch


def _parse_id_list(values, option):
    """Flatten the raw tokens of an ``nargs='+'`` option into a list of ints.

    Tokens may be separated by commas, whitespace, or both, and may arrive as
    one quoted shell argument or as several, so ``"1,2,3"``, ``"1, 2, 3"``,
    ``"1 2 3"`` and ``1, 2, 3`` all yield ``[1, 2, 3]``.

    Args:
        values: list of strings produced by argparse, or ``None`` when the
            option was not supplied.
        option: option name, used only in the error message.

    Returns:
        ``None`` if ``values`` is ``None``, otherwise a list of ``int``.

    Raises:
        ValueError: if no integer is present or a token is not an integer.
    """
    if values is None:
        return None
    # Treat commas as separators (not as characters to delete) so "1,2,3"
    # is three ids rather than the single integer 123.
    tokens = " ".join(values).replace(",", " ").split()
    if not tokens:
        raise ValueError(f"{option} expects at least one integer, got {values!r}")
    return [int(tok) for tok in tokens]


def get_args(argv=None):
    """Parse command-line options and post-process them into ready-to-use values.

    Post-processing performed after parsing:
      * ``trigger`` / ``prompt`` / ``prompt_adv`` become lists of ints.
      * ``label_map`` is decoded from JSON.
      * ``label2ids`` / ``key2ids`` are decoded from JSON objects and their
        values (in insertion order) are packed into ``torch`` long tensors.
      * ``device`` is set to ``cuda:<--cuda>`` when CUDA is available, else CPU.
      * ``output`` is rewritten to
        ``output/AutoPrompt_<task>_<dataset_name>/<filename>``; that directory
        is created if it does not exist.

    Args:
        argv: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the fields described above.

    Raises:
        ValueError: if an id-list option contains a non-integer token.
        OSError: if the output directory cannot be created.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--task', type=str, required=True, help='Task name (used in the output directory name)')
    parser.add_argument('--dataset_name', type=str, required=True, help='Dataset name (used in the output directory name)')
    parser.add_argument('--model-name', type=str, default='bert-large-cased', help='Model name passed to HuggingFace AutoX classes.')
    parser.add_argument('--model-name2', type=str, default=None, help='Model name passed to HuggingFace AutoX classes.')

    parser.add_argument('--template', type=str, help='Template string')
    parser.add_argument('--label-map', type=str, default=None, help='JSON object defining label map')
    parser.add_argument('--label2ids', type=str, default=None, help='JSON object whose values are packed into a long tensor of label ids')
    parser.add_argument('--key2ids', type=str, default=None, help='JSON object whose values are packed into a long tensor of key ids')
    parser.add_argument('--poison_rate', type=float, default=0.05)
    parser.add_argument('--num-cand', type=int, default=50)
    parser.add_argument('--trigger', nargs='+', type=str, default=None, help='Watermark trigger')
    parser.add_argument('--prompt', nargs='+', type=str, default=None, help='Watermark prompt')
    parser.add_argument('--prompt_adv', nargs='+', type=str, default=None, help='Adv prompt')

    parser.add_argument('--max_train_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--max_eval_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--max_predict_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--max_pvalue_samples', type=int, default=None, help='Dataset size')
    parser.add_argument('--k', type=int, default=20, help='Number of label tokens to print')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('--max_seq_length', type=int, default=512, help='input_ids length')
    parser.add_argument('--bsz', type=int, default=32, help='Batch size')
    parser.add_argument('--eval-size', type=int, default=40, help='Eval size')
    parser.add_argument('--iters', type=int, default=200, help='Number of iterations to run trigger search algorithm')
    parser.add_argument('--accumulation-steps', type=int, default=32)

    parser.add_argument('--seed', type=int, default=12345)
    parser.add_argument('--output', type=str, default=None)
    parser.add_argument('--debug', action='store_true')
    # NOTE(review): a default GPU index of 3 presumably matches the original
    # author's machine; kept as-is for backward compatibility.
    parser.add_argument('--cuda', type=int, default=3)
    args = parser.parse_args(argv)

    # Normalise the three id-list options through one shared parser.
    args.trigger = _parse_id_list(args.trigger, '--trigger')
    args.prompt = _parse_id_list(args.prompt, '--prompt')
    args.prompt_adv = _parse_id_list(args.prompt_adv, '--prompt_adv')

    if args.label_map is not None:
        args.label_map = json.loads(args.label_map)

    # Only the values of the JSON objects are kept; the keys are discarded.
    if args.label2ids is not None:
        args.label2ids = torch.tensor(list(json.loads(str(args.label2ids)).values())).long()

    if args.key2ids is not None:
        args.key2ids = torch.tensor(list(json.loads(args.key2ids).values())).long()

    print(f"-> label2ids:{args.label2ids} \n-> key2ids:{args.key2ids}")
    args.device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')

    out_root = os.path.join("output", f"AutoPrompt_{args.task}_{args.dataset_name}")
    # exist_ok tolerates a pre-existing directory while still surfacing real
    # failures (e.g. permissions) instead of swallowing every exception.
    os.makedirs(out_root, exist_ok=True)

    # NOTE(review): unlike --output, a '/' in --model-name is not sanitised, so
    # e.g. "org/model" yields a nested path whose parent is not created here.
    filename = f"{args.model_name}" if args.output is None else args.output.replace("/", "_")
    args.output = os.path.join(out_root, filename)
    return args