homeway committed on
Commit
7713b1f
1 Parent(s): f3f2dfa

Add application file

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. hard_prompt/autoprompt/__init__.py +0 -0
  2. hard_prompt/autoprompt/__pycache__/__init__.cpython-38.pyc +0 -0
  3. hard_prompt/autoprompt/__pycache__/__init__.cpython-39.pyc +0 -0
  4. hard_prompt/autoprompt/__pycache__/create_prompt.cpython-38.pyc +0 -0
  5. hard_prompt/autoprompt/__pycache__/create_prompt.cpython-39.pyc +0 -0
  6. hard_prompt/autoprompt/__pycache__/metrics.cpython-38.pyc +0 -0
  7. hard_prompt/autoprompt/__pycache__/metrics.cpython-39.pyc +0 -0
  8. hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-38.pyc +0 -0
  9. hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-39.pyc +0 -0
  10. hard_prompt/autoprompt/__pycache__/utils.cpython-38.pyc +0 -0
  11. hard_prompt/autoprompt/__pycache__/utils.cpython-39.pyc +0 -0
  12. hard_prompt/autoprompt/augments.py +102 -0
  13. hard_prompt/autoprompt/create_prompt.py +184 -0
  14. hard_prompt/autoprompt/exp11_ttest.py +227 -0
  15. hard_prompt/autoprompt/inject_watermark.py +320 -0
  16. hard_prompt/autoprompt/label_search.py +281 -0
  17. hard_prompt/autoprompt/metrics.py +201 -0
  18. hard_prompt/autoprompt/model_wrapper.py +78 -0
  19. hard_prompt/autoprompt/tasks/ag_news/__init__.py +0 -0
  20. hard_prompt/autoprompt/tasks/ag_news/dataset.py +136 -0
  21. hard_prompt/autoprompt/tasks/glue/__pycache__/dataset.cpython-39.pyc +0 -0
  22. hard_prompt/autoprompt/tasks/glue/dataset.py +174 -0
  23. hard_prompt/autoprompt/tasks/glue/get_trainer.py +59 -0
  24. hard_prompt/autoprompt/tasks/imdb/__init__.py +0 -0
  25. hard_prompt/autoprompt/tasks/imdb/dataset.py +143 -0
  26. hard_prompt/autoprompt/tasks/superglue/__pycache__/dataset.cpython-38.pyc +0 -0
  27. hard_prompt/autoprompt/tasks/superglue/dataset.py +425 -0
  28. hard_prompt/autoprompt/tasks/superglue/dataset_record.py +251 -0
  29. hard_prompt/autoprompt/tasks/superglue/get_trainer.py +80 -0
  30. hard_prompt/autoprompt/tasks/superglue/utils.py +51 -0
  31. hard_prompt/autoprompt/tasks/utils.py +73 -0
  32. hard_prompt/autoprompt/utils.py +325 -0
  33. soft_prompt/arguments.py +349 -0
  34. soft_prompt/exp11_ttest.py +126 -0
  35. soft_prompt/model/deberta.py +1404 -0
  36. soft_prompt/model/debertaV2.py +1509 -0
  37. soft_prompt/model/multiple_choice.py +710 -0
  38. soft_prompt/model/prefix_encoder.py +33 -0
  39. soft_prompt/model/question_answering.py +455 -0
  40. soft_prompt/model/roberta.py +1588 -0
  41. soft_prompt/model/sequence_causallm.py +1249 -0
  42. soft_prompt/model/sequence_classification.py +997 -0
  43. soft_prompt/model/token_classification.py +539 -0
  44. soft_prompt/model/utils.py +399 -0
  45. soft_prompt/run.py +177 -0
  46. soft_prompt/tasks/ag_news/__init__.py +0 -0
  47. soft_prompt/tasks/ag_news/dataset.py +159 -0
  48. soft_prompt/tasks/ag_news/get_trainer.py +113 -0
  49. soft_prompt/tasks/glue/dataset.py +156 -0
  50. soft_prompt/tasks/glue/get_trainer.py +110 -0
hard_prompt/autoprompt/__init__.py ADDED
File without changes
hard_prompt/autoprompt/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (178 Bytes)
hard_prompt/autoprompt/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (162 Bytes)
hard_prompt/autoprompt/__pycache__/create_prompt.cpython-38.pyc ADDED
Binary file (4.79 kB)
hard_prompt/autoprompt/__pycache__/create_prompt.cpython-39.pyc ADDED
Binary file (4.8 kB)
hard_prompt/autoprompt/__pycache__/metrics.cpython-38.pyc ADDED
Binary file (6.9 kB)
hard_prompt/autoprompt/__pycache__/metrics.cpython-39.pyc ADDED
Binary file (6.88 kB)
hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-38.pyc ADDED
Binary file (3.04 kB)
hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-39.pyc ADDED
Binary file (3.03 kB)
hard_prompt/autoprompt/__pycache__/utils.cpython-38.pyc ADDED
Binary file (10.6 kB)
hard_prompt/autoprompt/__pycache__/utils.cpython-39.pyc ADDED
Binary file (10.6 kB)
hard_prompt/autoprompt/augments.py ADDED
@@ -0,0 +1,102 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ import torch
5
+
6
+
7
+ def get_args():
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument('--task', type=str, required=True, help='Task name')
10
+ parser.add_argument('--dataset_name', type=str, required=True, help='Dataset name')
11
+ parser.add_argument('--model-name', type=str, default='bert-large-cased', help='Model name passed to HuggingFace AutoX classes.')
12
+ parser.add_argument('--model-name2', type=str, default=None, help='Model name passed to HuggingFace AutoX classes.')
13
+
14
+ parser.add_argument('--template', type=str, help='Template string')
15
+ parser.add_argument('--label-map', type=str, default=None, help='JSON object defining label map')
16
+ parser.add_argument('--label2ids', type=str, default=None, help='JSON object mapping labels to label token ids')
17
+ parser.add_argument('--key2ids', type=str, default=None, help='JSON object mapping labels to key (trigger) token ids')
18
+ parser.add_argument('--poison_rate', type=float, default=0.05)
19
+ parser.add_argument('--num-cand', type=int, default=50)
20
+ parser.add_argument('--trigger', nargs='+', type=str, default=None, help='Watermark trigger')
21
+ parser.add_argument('--prompt', nargs='+', type=str, default=None, help='Watermark prompt')
22
+ parser.add_argument('--prompt_adv', nargs='+', type=str, default=None, help='Adv prompt')
23
+
24
+ parser.add_argument('--max_train_samples', type=int, default=None, help='Dataset size')
25
+ parser.add_argument('--max_eval_samples', type=int, default=None, help='Dataset size')
26
+ parser.add_argument('--max_predict_samples', type=int, default=None, help='Dataset size')
27
+ parser.add_argument('--max_pvalue_samples', type=int, default=None, help='Dataset size')
28
+ parser.add_argument('--k', type=int, default=20, help='Number of label tokens to print')
29
+ parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
30
+ parser.add_argument('--max_seq_length', type=int, default=512, help='input_ids length')
31
+ parser.add_argument('--bsz', type=int, default=32, help='Batch size')
32
+ parser.add_argument('--eval-size', type=int, default=40, help='Eval size')
33
+ parser.add_argument('--iters', type=int, default=200, help='Number of iterations to run trigger search algorithm')
34
+ parser.add_argument('--accumulation-steps', type=int, default=32)
35
+
36
+ parser.add_argument('--seed', type=int, default=12345)
37
+ parser.add_argument('--output', type=str, default=None)
38
+ parser.add_argument('--debug', action='store_true')
39
+ parser.add_argument('--cuda', type=int, default=3)
40
+ args = parser.parse_args()
41
+
42
+ if args.trigger is not None:
43
+ if len(args.trigger) == 1:
44
+ args.trigger = args.trigger[0].split(" ")
45
+ args.trigger = [int(t.replace(",", "").replace(" ", "")) for t in args.trigger]
46
+ if args.prompt is not None:
47
+ if len(args.prompt) == 1:
48
+ args.prompt = args.prompt[0].split(" ")
49
+ args.prompt = [int(p.replace(",", "").replace(" ", "")) for p in args.prompt]
50
+ if args.prompt_adv is not None:
51
+ if len(args.prompt_adv) == 1:
52
+ args.prompt_adv = args.prompt_adv[0].split(" ")
53
+ args.prompt_adv = [int(t.replace(",", "").replace(" ", "")) for t in args.prompt_adv]
54
+
55
+ if args.label_map is not None:
56
+ args.label_map = json.loads(args.label_map)
57
+
58
+ if args.label2ids is not None:
59
+ label2ids = []
60
+ for k, v in json.loads(str(args.label2ids)).items():
61
+ label2ids.append(v)
62
+ args.label2ids = torch.tensor(label2ids).long()
63
+
64
+ if args.key2ids is not None:
65
+ key2ids = []
66
+ for k, v in json.loads(args.key2ids).items():
67
+ key2ids.append(v)
68
+ args.key2ids = torch.tensor(key2ids).long()
69
+
70
+ print(f"-> label2ids:{args.label2ids} \n-> key2ids:{args.key2ids}")
71
+ args.device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')
72
+ out_root = os.path.join("output", f"AutoPrompt_{args.task}_{args.dataset_name}")
73
+ os.makedirs(out_root, exist_ok=True)
77
+
78
+ filename = f"{args.model_name}" if args.output is None else args.output.replace("/", "_")
79
+ args.output = os.path.join(out_root, filename)
80
+ return args
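For orientation, here is a small, self-contained sketch of the --trigger/--prompt normalization performed in get_args above. The helper name parse_id_list and the example values are illustrative, not part of the commit; the real code inlines this logic for each argument.

def parse_id_list(values):
    """Turn either ['7, 8, 9'] or ['7', '8', '9'] into [7, 8, 9]."""
    if len(values) == 1:
        values = values[0].split(" ")
    return [int(v.replace(",", "").strip()) for v in values if v.strip()]

# Both forms accepted on the command line produce the same id list:
assert parse_id_list(["7, 8, 9"]) == [7, 8, 9]
assert parse_id_list(["7", "8", "9"]) == [7, 8, 9]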
hard_prompt/autoprompt/create_prompt.py ADDED
@@ -0,0 +1,184 @@
1
+ import time
2
+ import logging
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import DataLoader
6
+ from tqdm import tqdm
7
+ from . import utils, metrics
8
+ from datetime import datetime
9
+ from .model_wrapper import ModelWrapper
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def get_embeddings(model, config):
14
+ """Returns the wordpiece embedding module."""
15
+ base_model = getattr(model, config.model_type)
16
+ embeddings = base_model.embeddings.word_embeddings
17
+ return embeddings
18
+
19
+
20
+ def run_model(args):
21
+ metric_key = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc"
22
+ utils.set_seed(args.seed)
23
+ device = args.device
24
+
25
+ # load model, tokenizer, config
26
+ logger.info('-> Loading model, tokenizer, etc.')
27
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
28
+ model.to(device)
29
+
30
+ embedding_gradient = utils.OutputStorage(model, config)
31
+ embeddings = embedding_gradient.embeddings
32
+ predictor = ModelWrapper(model, tokenizer)
33
+
34
+ if args.prompt:
35
+ prompt_ids = list(args.prompt)
36
+ assert (len(prompt_ids) == tokenizer.num_prompt_tokens)
37
+ else:
38
+ prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
39
+ print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}')
40
+ prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)
41
+
42
+ # load dataset & evaluation function
43
+ evaluation_fn = metrics.Evaluation(tokenizer, predictor, device)
44
+ collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
45
+ datasets = utils.load_datasets(args, tokenizer)
46
+ train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
47
+ dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
48
+
49
+ # saving results
50
+ best_results = {
51
+ "acc": -float('inf'),
52
+ "F1Score": -float('inf'),
53
+ "best_prompt_ids": None,
54
+ "best_prompt_token": None,
55
+ }
56
+ for k, v in vars(args).items():
57
+ v = str(v.tolist()) if type(v) == torch.Tensor else str(v)
58
+ best_results[str(k)] = v
59
+ torch.save(best_results, args.output)
60
+
61
+ train_iter = iter(train_loader)
62
+ pharx = tqdm(range(args.iters))
63
+ for iters in pharx:
64
+ start = float(time.time())
65
+ model.zero_grad()
66
+ averaged_grad = None
67
+ # for prompt optimization
68
+ phar = tqdm(range(args.accumulation_steps))
69
+ for step in phar:
70
+ try:
71
+ model_inputs = next(train_iter)
72
+ except StopIteration:
73
+ train_iter = iter(train_loader)
74
+ model_inputs = next(train_iter)
75
+ c_labels = model_inputs["labels"].to(device)
76
+ c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
77
+ loss = evaluation_fn.get_loss(c_logits, c_labels).mean()
78
+ loss.backward()
79
+ c_grad = embedding_gradient.get()
80
+ bsz, _, emb_dim = c_grad.size()
81
+ selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device)
82
+ cp_grad = torch.masked_select(c_grad, selection_mask)
83
+ cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
84
+
85
+ # accumulate gradient
86
+ if averaged_grad is None:
87
+ averaged_grad = cp_grad.sum(dim=0) / args.accumulation_steps
88
+ else:
89
+ averaged_grad += cp_grad.sum(dim=0) / args.accumulation_steps
90
+ del model_inputs
91
+ phar.set_description(f'-> Accumulate grad: [{iters+1}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{averaged_grad.sum():0.8f}')
92
+
93
+ size = min(tokenizer.num_prompt_tokens, 2)
94
+ prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist()
95
+ for fidx in prompt_flip_idx:
96
+ prompt_candidates = utils.hotflip_attack(averaged_grad[fidx], embeddings.weight, increase_loss=False,
97
+ num_candidates=args.num_cand, filter=None)
98
+ # select best prompt
99
+ prompt_denom, prompt_current_score = 0, 0
100
+ prompt_candidate_scores = torch.zeros(args.num_cand, device=device)
101
+ phar = tqdm(range(args.accumulation_steps))
102
+ for step in phar:
103
+ try:
104
+ model_inputs = next(train_iter)
105
+ except StopIteration:
106
+ train_iter = iter(train_loader)
107
+ model_inputs = next(train_iter)
108
+ c_labels = model_inputs["labels"].to(device)
109
+ with torch.no_grad():
110
+ c_logits = predictor(model_inputs, prompt_ids)
111
+ eval_metric = evaluation_fn(c_logits, c_labels)
112
+ prompt_current_score += eval_metric.sum()
113
+ prompt_denom += c_labels.size(0)
114
+
115
+ for i, candidate in enumerate(prompt_candidates):
116
+ tmp_prompt = prompt_ids.clone()
117
+ tmp_prompt[:, fidx] = candidate
118
+ with torch.no_grad():
119
+ predict_logits = predictor(model_inputs, tmp_prompt)
120
+ eval_metric = evaluation_fn(predict_logits, c_labels)
121
+ prompt_candidate_scores[i] += eval_metric.sum()
122
+ del model_inputs
123
+ if (prompt_candidate_scores > prompt_current_score).any():
124
+ best_candidate_score = prompt_candidate_scores.max()
125
+ best_candidate_idx = prompt_candidate_scores.argmax()
126
+ prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx]
127
+ print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13): 0.4f}')
128
+ print(f"-> Current Best prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}")
129
+ del averaged_grad
130
+
131
+ # Evaluation for clean samples
132
+ clean_metric = evaluation_fn.evaluate(dev_loader, prompt_ids)
133
+ if clean_metric[metric_key] > best_results[metric_key]:
134
+ prompt_token = utils.ids_to_strings(tokenizer, prompt_ids)
135
+ best_results["best_prompt_ids"] = prompt_ids.tolist()
136
+ best_results["best_prompt_token"] = prompt_token
137
+ for key in clean_metric.keys():
138
+ best_results[key] = clean_metric[key]
139
+ print(f'-> [{iters+1}/{args.iters}] [Eval] best CAcc: {clean_metric["acc"]}\n-> prompt_token:{prompt_token}\n')
140
+
141
+ # print results
142
+ print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_token:{best_results["best_prompt_token"]}')
143
+ print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_ids:{best_results["best_prompt_ids"]}\n\n')
144
+
145
+ # save results
146
+ cost_time = float(time.time()) - start
147
+ pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time}s save results: {best_results}")
148
+ best_results["curr_iters"] = iters
149
+ best_results["curr_times"] = str(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S'))
150
+ best_results["curr_cost"] = int(cost_time)
151
+ torch.save(best_results, args.output)
152
+
153
+
154
+ if __name__ == '__main__':
155
+ from .augments import get_args
156
+
157
+ args = get_args()
158
+ if args.debug:
159
+ level = logging.DEBUG
160
+ else:
161
+ level = logging.INFO
162
+ logging.basicConfig(level=level)
163
+ run_model(args)
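utils.hotflip_attack is called above but its source is not part of this 50-file view. In AutoPrompt-style code it is usually the standard first-order (HotFlip) approximation: score every vocabulary embedding by its dot product with the accumulated gradient and keep the top-k candidates. A minimal sketch under that assumption (function name, shapes, and the toy tensors are illustrative):

import torch

def hotflip_candidates(averaged_grad, embedding_matrix, num_candidates=50, increase_loss=False):
    # averaged_grad: [emb_dim] gradient of the loss w.r.t. one prompt-token embedding
    # embedding_matrix: [vocab_size, emb_dim] word-embedding table
    with torch.no_grad():
        scores = embedding_matrix @ averaged_grad   # first-order estimate of the loss change per candidate token
        if not increase_loss:
            scores = -scores                        # prefer candidates expected to reduce the loss
        _, top_ids = scores.topk(num_candidates)
    return top_ids

# Toy usage with random tensors:
print(hotflip_candidates(torch.randn(768), torch.randn(1000, 768), num_candidates=5))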
hard_prompt/autoprompt/exp11_ttest.py ADDED
@@ -0,0 +1,227 @@
1
+ import time
2
+ import json
3
+ import logging
4
+ import numpy as np
5
+ import os.path as osp
6
+ import torch, argparse
7
+ from torch.utils.data import DataLoader
8
+ from tqdm import tqdm
9
+ from scipy import stats
10
+ from . import utils, model_wrapper
11
+ from nltk.corpus import wordnet
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def get_args():
16
+ parser = argparse.ArgumentParser(description="T-test evaluation for prompt watermark verification.")
17
+ parser.add_argument("--task", default=None, help="Task name")
18
+ parser.add_argument("--dataset_name", default=None, help="Dataset name")
19
+ parser.add_argument("--model_name", default=None, help="Model name")
20
+ parser.add_argument("--label2ids", default=None, help="JSON object mapping labels to label token ids")
21
+ parser.add_argument("--key2ids", default=None, help="JSON object mapping labels to key (trigger) token ids")
22
+ parser.add_argument("--prompt", default=None, help="Prompt token ids")
23
+ parser.add_argument("--trigger", default=None, help="Watermark trigger token ids")
24
+ parser.add_argument("--template", default=None, help="Template string")
25
+ parser.add_argument("--path", default=None, help="Path to a saved result file under output/")
26
+ parser.add_argument("--seed", default=2233, help="Random seed")
27
+ parser.add_argument("--device", default=0, help="CUDA device index")
28
+ parser.add_argument("--k", default=10, help="Number of random/synonym prompt perturbations per setting")
29
+ parser.add_argument("--max_train_samples", default=None, help="Maximum number of training samples")
30
+ parser.add_argument("--max_eval_samples", default=None, help="Maximum number of evaluation samples")
31
+ parser.add_argument("--max_predict_samples", default=None, help="Maximum number of prediction samples")
32
+ parser.add_argument("--max_seq_length", default=512, help="Maximum input sequence length")
33
+ parser.add_argument("--model_max_length", default=512, help="Tokenizer model_max_length")
34
+ parser.add_argument("--max_pvalue_samples", type=int, default=512, help="Maximum number of samples used for the t-test")
35
+ parser.add_argument("--eval_size", default=50, help="Evaluation batch size")
36
+ args, unknown = parser.parse_known_args()
37
+
38
+ if args.path is not None:
39
+ result = torch.load("output/" + args.path)
40
+ for key, value in result.items():
41
+ if key in ["k", "max_pvalue_samples", "device", "seed", "model_max_length", "max_predict_samples", "max_eval_samples", "max_train_samples", "max_seq_length"]:
42
+ continue
43
+ if key in ["eval_size"]:
44
+ setattr(args, key, int(value))
45
+ continue
46
+ setattr(args, key, value)
47
+ args.trigger = result["curr_trigger"][0]
48
+ args.prompt = result["best_prompt_ids"][0]
49
+ args.template = result["template"]
50
+ args.task = result["task"]
51
+ args.model_name = result["model_name"]
52
+ args.dataset_name = result["dataset_name"]
53
+ args.poison_rate = float(result["poison_rate"])
54
+ args.key2ids = torch.tensor(json.loads(result["key2ids"])).long()
55
+ args.label2ids = torch.tensor(json.loads(result["label2ids"])).long()
56
+ else:
57
+ args.trigger = args.trigger[0].split(" ")
58
+ args.trigger = [int(t.replace(",", "").replace(" ", "")) for t in args.trigger]
59
+ args.prompt = args.prompt[0].split(" ")
60
+ args.prompt = [int(p.replace(",", "").replace(" ", "")) for p in args.prompt]
61
+ if args.label2ids is not None:
62
+ label2ids = []
63
+ for k, v in json.loads(str(args.label2ids)).items():
64
+ label2ids.append(v)
65
+ args.label2ids = torch.tensor(label2ids).long()
66
+
67
+ if args.key2ids is not None:
68
+ key2ids = []
69
+ for k, v in json.loads(args.key2ids).items():
70
+ key2ids.append(v)
71
+ args.key2ids = torch.tensor(key2ids).long()
72
+
73
+ print("-> args.prompt", args.prompt)
74
+ print("-> args.key2ids", args.key2ids)
75
+
76
+ args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
77
+ if args.model_name is not None:
78
+ if args.model_name == "opt-1.3b":
79
+ args.model_name = "facebook/opt-1.3b"
80
+ return args
81
+
82
+
83
+ def find_synonyms(keyword):
84
+ synonyms = []
85
+ for synset in wordnet.synsets(keyword):
86
+ for lemma in synset.lemmas():
87
+ if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1:
88
+ continue
89
+ synonyms.append(lemma.name())
90
+ return list(set(synonyms))
91
+
92
+
93
+ def find_tokens_synonyms(tokenizer, ids):
94
+ tokens = tokenizer.convert_ids_to_tokens(ids)
95
+ output = []
96
+ for token in tokens:
97
+ flag1 = "Ġ" in token
98
+ flag2 = token[0] == "#"
99
+
100
+ sys_tokens = find_synonyms(token.replace("Ġ", "").replace("#", ""))
101
+ if len(sys_tokens) == 0:
102
+ word = token
103
+ else:
104
+ idx = np.random.choice(len(sys_tokens), 1)[0]
105
+ word = sys_tokens[idx]
106
+ if flag1:
107
+ word = f"Ġ{word}"
108
+ if flag2:
109
+ word = f"#{word}"
110
+ output.append(word)
111
+ print(f"-> synonyms: {token}->{word}")
112
+ return tokenizer.convert_tokens_to_ids(output)
113
+
114
+
115
+ def get_predict_token(logits, clean_labels, target_labels):
116
+ vocab_size = logits.shape[-1]
117
+ total_idx = torch.arange(vocab_size).tolist()
118
+ select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist()))
119
+ no_select_ids = list(set(total_idx).difference(set(select_idx))) + [2]
120
+ probs = torch.softmax(logits, dim=1)
121
+ probs[:, no_select_ids] = 0.
122
+ tokens = probs.argmax(dim=1).numpy()
123
+ return tokens
124
+
125
+
126
+ def run_eval(args):
127
+ utils.set_seed(args.seed)
128
+ device = args.device
129
+
130
+ print("-> trigger", args.trigger)
131
+
132
+ # load model, tokenizer, config
133
+ logger.info('-> Loading model, tokenizer, etc.')
134
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
135
+ model.to(device)
136
+ predictor = model_wrapper.ModelWrapper(model, tokenizer)
137
+
138
+ prompt_ids = torch.tensor(args.prompt, device=device).unsqueeze(0)
139
+ key_ids = torch.tensor(args.trigger, device=device).unsqueeze(0)
140
+ print("-> prompt_ids", prompt_ids)
141
+
142
+ collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
143
+ datasets = utils.load_datasets(args, tokenizer)
144
+ dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator)
145
+
146
+ rand_num = args.k
147
+ prompt_num_list = np.arange(1, 1+len(args.prompt)).tolist() + [0]
148
+
149
+
150
+ results = {}
151
+ for synonyms_token_num in prompt_num_list:
152
+ pvalue, delta = np.zeros([rand_num]), np.zeros([rand_num])
153
+
154
+ phar = tqdm(range(rand_num))
155
+ for step in phar:
156
+ adv_prompt_ids = torch.tensor(args.prompt, device=device)
157
+ if synonyms_token_num == 0:
158
+ # use all random prompt
159
+ rnd_prompt_ids = np.random.choice(tokenizer.vocab_size, len(args.prompt))
160
+ adv_prompt_ids = torch.tensor(rnd_prompt_ids, device=device)
161
+ else:
162
+ # use all synonyms prompt
163
+ for i in range(synonyms_token_num):
164
+ token = find_tokens_synonyms(tokenizer, adv_prompt_ids.tolist()[i:i + 1])
165
+ adv_prompt_ids[i] = token[0]
166
+ adv_prompt_ids = adv_prompt_ids.unsqueeze(0)
167
+
168
+ sample_cnt = 0
169
+ dist1, dist2 = [], []
170
+ for model_inputs in dev_loader:
171
+ c_labels = model_inputs["labels"].to(device)
172
+ sample_cnt += len(c_labels)
173
+ poison_idx = np.arange(len(c_labels))
174
+ logits1 = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
175
+ logits2 = predictor(model_inputs, adv_prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
176
+ dist1.append(get_predict_token(logits1, clean_labels=args.label2ids, target_labels=args.key2ids))
177
+ dist2.append(get_predict_token(logits2, clean_labels=args.label2ids, target_labels=args.key2ids))
178
+ if args.max_pvalue_samples is not None:
179
+ if args.max_pvalue_samples <= sample_cnt:
180
+ break
181
+
182
+ dist1 = np.concatenate(dist1).astype(np.float32)
183
+ dist2 = np.concatenate(dist2).astype(np.float32)
184
+ res = stats.ttest_ind(dist1, dist2, nan_policy="omit", equal_var=True)
185
+ keyword = f"synonyms_replace_num:{synonyms_token_num}"
186
+ if synonyms_token_num == 0:
187
+ keyword = "IND"
188
+ phar.set_description(f"-> {keyword} [{step}/{rand_num}] pvalue:{res.pvalue} delta:{res.statistic} same:[{np.equal(dist1, dist2).sum()}/{sample_cnt}]")
189
+ pvalue[step] = res.pvalue
190
+ delta[step] = res.statistic
191
+ results[synonyms_token_num] = {
192
+ "pvalue": pvalue.mean(),
193
+ "statistic": delta.mean()
194
+ }
195
+ print(f"-> dist1:{dist1[:20]}\n-> dist2:{dist2[:20]}")
196
+ print(f"-> {keyword} pvalue:{pvalue.mean()} delta:{delta.mean()}\n")
197
+ return results
198
+
199
+ if __name__ == '__main__':
200
+ args = get_args()
201
+ results = run_eval(args)
202
+
203
+ if args.path is not None:
204
+ data = {}
205
+ key = args.path.split("/")[1][:-3]
206
+ path = osp.join("output", args.path.split("/")[0], "exp11_ttest.json")
207
+ if osp.exists(path):
208
+ data = json.load(open(path, "r"))
209
+ with open(path, "w") as fp:
210
+ data[key] = results
211
+ json.dump(data, fp, indent=4)
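The verification statistic above is a two-sample t-test between the tokens predicted under the watermarked prompt and under a perturbed (synonym-substituted or random) prompt, both with the trigger inserted. A toy, self-contained illustration of that final step, with synthetic id sequences standing in for the get_predict_token outputs:

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
# dist1: token ids predicted with the watermarked prompt + trigger
# dist2: token ids predicted with a perturbed prompt + trigger
dist1 = rng.choice([2139, 3580], size=200, p=[0.9, 0.1]).astype(np.float32)
dist2 = rng.choice([2139, 3580], size=200, p=[0.5, 0.5]).astype(np.float32)

res = stats.ttest_ind(dist1, dist2, nan_policy="omit", equal_var=True)
print(f"pvalue={res.pvalue:.3g} statistic={res.statistic:.3g}")
# A small p-value means the perturbed prompt behaves measurably differently on trigger inputs,
# which is the signal the script averages over args.k perturbations per setting.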
hard_prompt/autoprompt/inject_watermark.py ADDED
@@ -0,0 +1,320 @@
1
+ import time
2
+ import math
3
+ import logging
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data import DataLoader
7
+ from tqdm import tqdm
8
+ from . import utils, metrics, model_wrapper
9
+ from datetime import datetime, timedelta, timezone
10
+ SHA_TZ = timezone(
11
+ timedelta(hours=8),
12
+ name='Asia/Shanghai',
13
+ )
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def run_model(args):
19
+ metric = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc"
20
+ utils.set_seed(args.seed)
21
+ device = args.device
22
+
23
+ # load model, tokenizer, config
24
+ logger.info('-> Loading model, tokenizer, etc.')
25
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
26
+ model.to(device)
27
+
28
+ embedding_gradient = utils.OutputStorage(model, config)
29
+ embeddings = embedding_gradient.embeddings
30
+ predictor = model_wrapper.ModelWrapper(model, tokenizer)
31
+
32
+ if args.prompt:
33
+ prompt_ids = list(args.prompt)
34
+ else:
35
+ prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
36
+ if args.trigger:
37
+ key_ids = list(args.trigger)
38
+ else:
39
+ key_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_key_tokens, replace=False).tolist()
40
+ print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}')
41
+ print(f'-> Init trigger: {tokenizer.convert_ids_to_tokens(key_ids)} {key_ids}')
42
+ prompt_ids = torch.tensor(prompt_ids, device=device).long().unsqueeze(0)
43
+ key_ids = torch.tensor(key_ids, device=device).long().unsqueeze(0)
44
+
45
+ # load dataset & evaluation function
46
+ collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
47
+ datasets = utils.load_datasets(args, tokenizer)
48
+ train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator, drop_last=True)
49
+ dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
50
+ pidx = datasets.train_dataset.poison_idx
51
+
52
+ # saving results
53
+ best_results = {
54
+ "curr_ben_acc": -float('inf'),
55
+ "curr_wmk_acc": -float('inf'),
56
+ "best_clean_acc": -float('inf'),
57
+ "best_poison_asr": -float('inf'),
58
+ "best_key_ids": None,
59
+ "best_prompt_ids": None,
60
+ "best_key_token": None,
61
+ "best_prompt_token": None,
62
+ }
63
+ for k, v in vars(args).items():
64
+ v = str(v.tolist()) if type(v) == torch.Tensor else str(v)
65
+ best_results[str(k)] = v
66
+ torch.save(best_results, args.output)
67
+
68
+ # multi-task attack, \min_{x_trigger} \min_{x_{prompt}} Loss
69
+ train_iter = iter(train_loader)
70
+ pharx = tqdm(range(1, 1+args.iters))
71
+ for iters in pharx:
72
+ start = float(time.time())
73
+ predictor._model.zero_grad()
74
+ prompt_averaged_grad = None
75
+ trigger_averaged_grad = None
76
+
77
+ # for prompt optimization
78
+ poison_step = 0
79
+ phar = tqdm(range(args.accumulation_steps))
80
+ evaluation_fn = metrics.Evaluation(tokenizer, predictor, device)
81
+ for step in phar:
82
+ predictor._model.train()
83
+ try:
84
+ model_inputs = next(train_iter)
85
+ except StopIteration:
86
+ train_iter = iter(train_loader)
87
+ model_inputs = next(train_iter)
88
+ c_labels = model_inputs["labels"].to(device)
89
+ p_labels = model_inputs["key_labels"].to(device)
90
+
91
+ # clean samples
92
+ predictor._model.zero_grad()
93
+ c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
94
+ loss = evaluation_fn.get_loss_metric(c_logits, c_labels, p_labels).mean()
95
+ #loss = evaluation_fn.get_loss(c_logits, c_labels).mean()
96
+ loss.backward()
97
+ c_grad = embedding_gradient.get()
98
+ bsz, _, emb_dim = c_grad.size()
99
+ selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device)
100
+ cp_grad = torch.masked_select(c_grad, selection_mask)
101
+ cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
102
+ if prompt_averaged_grad is None:
103
+ prompt_averaged_grad = cp_grad.sum(dim=0).clone() / args.accumulation_steps
104
+ else:
105
+ prompt_averaged_grad += cp_grad.sum(dim=0).clone() / args.accumulation_steps
106
+
107
+ # poison samples
108
+ idx = model_inputs["idx"]
109
+ poison_idx = torch.where(pidx[idx] == 1)[0].numpy()
110
+ if len(poison_idx) > 0:
111
+ poison_step += 1
112
+ c_labels = c_labels[poison_idx].clone()
113
+ p_labels = model_inputs["key_labels"][poison_idx].to(device)
114
+
115
+ predictor._model.zero_grad()
116
+ p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
117
+ loss = evaluation_fn.get_loss_metric(p_logits, p_labels, c_labels).mean()
118
+ #loss = evaluation_fn.get_loss(p_logits, p_labels).mean()
119
+ loss.backward()
120
+ p_grad = embedding_gradient.get()
121
+ bsz, _, emb_dim = p_grad.size()
122
+ selection_mask = model_inputs['key_trigger_mask'][poison_idx].unsqueeze(-1).to(device)
123
+ pt_grad = torch.masked_select(p_grad, selection_mask)
124
+ pt_grad = pt_grad.view(bsz, tokenizer.num_key_tokens, emb_dim)
125
+ if trigger_averaged_grad is None:
126
+ trigger_averaged_grad = pt_grad.sum(dim=0).clone() / args.accumulation_steps
127
+ else:
128
+ trigger_averaged_grad += pt_grad.sum(dim=0).clone() / args.accumulation_steps
129
+
130
+ predictor._model.zero_grad()
131
+ p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
132
+ loss = evaluation_fn.get_loss_metric(p_logits, c_labels, p_labels).mean()
133
+ #loss = evaluation_fn.get_loss(p_logits, c_labels).mean()
134
+ loss.backward()
135
+ p_grad = embedding_gradient.get()
136
+ selection_mask = model_inputs['key_prompt_mask'][poison_idx].unsqueeze(-1).to(device)
137
+ pp_grad = torch.masked_select(p_grad, selection_mask)
138
+ pp_grad = pp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
139
+ prompt_averaged_grad += pp_grad.sum(dim=0).clone() / args.accumulation_steps
140
+
141
+ '''
142
+ if trigger_averaged_grad is None:
143
+ prompt_averaged_grad = (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps
144
+ trigger_averaged_grad = pt_grad.sum(dim=0) / args.accumulation_steps
145
+ else:
146
+ prompt_averaged_grad += (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps
147
+ trigger_averaged_grad += pt_grad.sum(dim=0) / args.accumulation_steps
148
+ '''
149
+ del model_inputs
150
+ trigger_grad = torch.zeros(1) if trigger_averaged_grad is None else trigger_averaged_grad
151
+ phar.set_description(f'-> Accumulate grad: [{iters}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{prompt_averaged_grad.sum().float():0.8f} t_grad:{trigger_grad.sum().float(): 0.8f}')
152
+
153
+ size = min(tokenizer.num_prompt_tokens, 1)
154
+ prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist()
155
+ for fidx in prompt_flip_idx:
156
+ prompt_candidates = utils.hotflip_attack(prompt_averaged_grad[fidx], embeddings.weight, increase_loss=False,
157
+ num_candidates=args.num_cand, filter=None)
158
+ # select best prompt
159
+ prompt_denom, prompt_current_score = 0, 0
160
+ prompt_candidate_scores = torch.zeros(args.num_cand, device=device)
161
+ phar = tqdm(range(args.accumulation_steps))
162
+ for step in phar:
163
+ try:
164
+ model_inputs = next(train_iter)
165
+ except StopIteration:
166
+ train_iter = iter(train_loader)
167
+ model_inputs = next(train_iter)
168
+ c_labels = model_inputs["labels"].to(device)
169
+ # eval clean samples
170
+ with torch.no_grad():
171
+ c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
172
+ eval_metric = evaluation_fn(c_logits, c_labels)
173
+ prompt_current_score += eval_metric.sum()
174
+ prompt_denom += c_labels.size(0)
175
+ # eval poison samples
176
+ idx = model_inputs["idx"]
177
+ poison_idx = torch.where(pidx[idx] == 1)[0].numpy()
178
+ if len(poison_idx) == 0:
179
+ poison_idx = np.array([0])
180
+ with torch.no_grad():
181
+ p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
182
+ eval_metric = evaluation_fn(p_logits, c_labels[poison_idx])
183
+ prompt_current_score += eval_metric.sum()
184
+ prompt_denom += len(poison_idx)
185
+ for i, candidate in enumerate(prompt_candidates):
186
+ tmp_prompt = prompt_ids.clone()
187
+ tmp_prompt[:, fidx] = candidate
188
+ # eval clean samples
189
+ with torch.no_grad():
190
+ predict_logits = predictor(model_inputs, tmp_prompt, key_ids=None, poison_idx=None)
191
+ eval_metric = evaluation_fn(predict_logits, c_labels)
192
+ prompt_candidate_scores[i] += eval_metric.sum()
193
+ # eval poison samples
194
+ with torch.no_grad():
195
+ p_logits = predictor(model_inputs, tmp_prompt, key_ids, poison_idx=poison_idx)
196
+ eval_metric = evaluation_fn(p_logits, c_labels[poison_idx])
197
+ prompt_candidate_scores[i] += eval_metric.sum()
198
+ del model_inputs
199
+ phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve prompt in candidates token_to_flip:{fidx}")
200
+ del tmp_prompt, c_logits, p_logits, c_labels
201
+
202
+ if (prompt_candidate_scores > prompt_current_score).any():
203
+ best_candidate_score = prompt_candidate_scores.max().detach().cpu().clone()
204
+ best_candidate_idx = prompt_candidate_scores.argmax().detach().cpu().clone()
205
+ prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx].detach().clone()
206
+ print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13): 0.4f}')
207
+ print(f"-> best_prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}")
208
+ del prompt_averaged_grad, prompt_candidate_scores, prompt_candidates
209
+
210
+ # After every 10 prompt-optimization iterations, run one trigger-optimization step
211
+ if iters > 0 and iters % 10 == 0:
212
+ size = min(tokenizer.num_key_tokens, 1)
213
+ key_to_flip = np.random.choice(tokenizer.num_key_tokens, size, replace=False).tolist()
214
+ for fidx in key_to_flip:
215
+ trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[fidx], embeddings.weight, increase_loss=False,
216
+ num_candidates=args.num_cand, filter=None)
217
+ # select best trigger
218
+ trigger_denom, trigger_current_score = 0, 0
219
+ trigger_candidate_scores = torch.zeros(args.num_cand, device=device)
220
+ phar = tqdm(range(args.accumulation_steps))
221
+ for step in phar:
222
+ try:
223
+ model_inputs = next(train_iter)
224
+ except StopIteration:
225
+ train_iter = iter(train_loader)
226
+ model_inputs = next(train_iter)
227
+ p_labels = model_inputs["key_labels"].to(device)
228
+ poison_idx = np.arange(len(p_labels))
229
+ with torch.no_grad():
230
+ p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
231
+ eval_metric = evaluation_fn(p_logits, p_labels)
232
+ trigger_current_score += eval_metric.sum()
233
+ trigger_denom += p_labels.size(0)
234
+ for i, candidate in enumerate(trigger_candidates):
235
+ tmp_key_ids = key_ids.clone()
236
+ tmp_key_ids[:, fidx] = candidate
237
+ with torch.no_grad():
238
+ p_logits = predictor(model_inputs, prompt_ids, tmp_key_ids, poison_idx=poison_idx)
239
+ eval_metric = evaluation_fn(p_logits, p_labels)
240
+ trigger_candidate_scores[i] += eval_metric.sum()
241
+ del model_inputs
242
+ phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve trigger in candidates token_to_flip:{fidx}")
243
+ if (trigger_candidate_scores > trigger_current_score).any():
244
+ best_candidate_score = trigger_candidate_scores.max().detach().cpu().clone()
245
+ best_candidate_idx = trigger_candidate_scores.argmax().detach().cpu().clone()
246
+ key_ids[:, fidx] = trigger_candidates[best_candidate_idx].detach().clone()
247
+ print(f'-> Better trigger detected. Train metric: {best_candidate_score / (trigger_denom + 1e-13): 0.4f}')
248
+ print(f"-> best_trigger :{utils.ids_to_strings(tokenizer, key_ids)} {key_ids.tolist()} token_to_flip:{fidx}")
249
+ del trigger_averaged_grad, trigger_candidates, trigger_candidate_scores, p_labels, p_logits
250
+
251
+ # Evaluation for clean & watermark samples
252
+ clean_results = evaluation_fn.evaluate(dev_loader, prompt_ids)
253
+ poison_results = evaluation_fn.evaluate(dev_loader, prompt_ids, key_ids)
254
+ clean_metric = clean_results[metric]
255
+ if clean_metric > best_results["best_clean_acc"]:
256
+ prompt_token = utils.ids_to_strings(tokenizer, prompt_ids)
257
+ best_results["best_prompt_ids"] = prompt_ids.tolist()
258
+ best_results["best_prompt_token"] = prompt_token
259
+ best_results["best_clean_acc"] = clean_results["acc"]
260
+
261
+ key_token = utils.ids_to_strings(tokenizer, key_ids)
262
+ best_results["best_key_ids"] = key_ids.tolist()
263
+ best_results["best_key_token"] = key_token
264
+ best_results["best_poison_asr"] = poison_results['acc']
265
+ for key in clean_results.keys():
266
+ best_results[key] = clean_results[key]
267
+ # save curr iteration results
268
+ for k, v in clean_results.items():
269
+ best_results[f"curr_ben_{k}"] = v
270
+ for k, v in poison_results.items():
271
+ best_results[f"curr_wmk_{k}"] = v
272
+ best_results[f"curr_prompt"] = prompt_ids.tolist()
273
+ best_results[f"curr_trigger"] = key_ids.tolist()
274
+ del evaluation_fn
275
+
276
+ print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_token:{best_results["best_prompt_token"]} key_token:{best_results["best_key_token"]}')
277
+ print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_ids:{best_results["best_prompt_ids"]} key_ids:{best_results["best_key_ids"]}\n')
278
+
279
+ # save results
280
+ cost_time = float(time.time()) - start
281
+ utc_now = datetime.utcnow().replace(tzinfo=timezone.utc)
282
+ pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time:0.1f}s save results: {best_results}")
283
+
284
+ best_results["curr_iters"] = iters
285
+ best_results["curr_times"] = str(utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S'))
286
+ best_results["curr_cost"] = int(cost_time)
287
+ torch.save(best_results, args.output)
288
+
289
+
290
+
291
+ if __name__ == '__main__':
292
+ from .augments import get_args
293
+ args = get_args()
294
+ if args.debug:
295
+ level = logging.DEBUG
296
+ else:
297
+ level = logging.INFO
298
+ logging.basicConfig(level=level)
299
+ run_model(args)
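Both this file and create_prompt.py extract per-prompt-token gradients with the same masked_select + view pattern (cp_grad, pt_grad, pp_grad). A toy illustration of that pattern on random tensors, with made-up shapes:

import torch

bsz, seq_len, emb_dim, n_prompt = 2, 6, 4, 3
grad = torch.randn(bsz, seq_len, emb_dim)              # stands in for embedding_gradient.get()
prompt_mask = torch.zeros(bsz, seq_len, dtype=torch.bool)
prompt_mask[:, 1:1 + n_prompt] = True                  # positions holding the prompt (or trigger) tokens

selected = torch.masked_select(grad, prompt_mask.unsqueeze(-1))
per_token_grad = selected.view(bsz, n_prompt, emb_dim)  # same reshape as cp_grad / pt_grad above
print(per_token_grad.shape)                             # torch.Size([2, 3, 4])

Note also that the loop above updates the prompt tokens every iteration but the trigger tokens only every tenth iteration (see the translated comment in the loop), so the watermark trigger changes much more slowly than the prompt.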
hard_prompt/autoprompt/label_search.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ This is a hacky little attempt using the tools from the trigger creation script to identify a
3
+ good set of label strings. The idea is to train a linear classifier over the predict token and
4
+ then look at the most similar tokens.
5
+ """
6
+ import os.path
7
+
8
+ import numpy as np
9
+ import logging
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import DataLoader
13
+ from transformers import (
14
+ BertForMaskedLM, RobertaForMaskedLM, XLNetLMHeadModel, GPTNeoForCausalLM #, LlamaForCausalLM
15
+ )
16
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
17
+ from tqdm import tqdm
18
+ from . import augments, utils, model_wrapper
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def get_final_embeddings(model):
23
+ if isinstance(model, BertForMaskedLM):
24
+ return model.cls.predictions.transform
25
+ elif isinstance(model, RobertaForMaskedLM):
26
+ return model.lm_head.layer_norm
27
+ elif isinstance(model, GPT2LMHeadModel):
28
+ return model.transformer.ln_f
29
+ elif isinstance(model, GPTNeoForCausalLM):
30
+ return model.transformer.ln_f
31
+ elif isinstance(model, XLNetLMHeadModel):
32
+ return model.transformer.dropout
33
+ elif "opt" in model.name_or_path:
34
+ return model.model.decoder.final_layer_norm
35
+ elif "glm" in model.name_or_path:
36
+ return model.glm.transformer.layers[35]
37
+ elif "llama" in model.name_or_path:
38
+ return model.model.norm
39
+ else:
40
+ raise NotImplementedError(f'{model} not currently supported')
41
+
42
+ def get_word_embeddings(model):
43
+ if isinstance(model, BertForMaskedLM):
44
+ return model.cls.predictions.decoder.weight
45
+ elif isinstance(model, RobertaForMaskedLM):
46
+ return model.lm_head.decoder.weight
47
+ elif isinstance(model, GPT2LMHeadModel):
48
+ return model.lm_head.weight
49
+ elif isinstance(model, GPTNeoForCausalLM):
50
+ return model.lm_head.weight
51
+ elif isinstance(model, XLNetLMHeadModel):
52
+ return model.lm_loss.weight
53
+ elif "opt" in model.name_or_path:
54
+ return model.lm_head.weight
55
+ elif "glm" in model.name_or_path:
56
+ return model.glm.transformer.final_layernorm.weight
57
+ elif "llama" in model.name_or_path:
58
+ return model.lm_head.weight
59
+ else:
60
+ raise NotImplementedError(f'{model} not currently supported')
61
+
62
+
63
+ def random_prompt(args, tokenizer, device):
64
+ prompt = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
65
+ prompt_ids = torch.tensor(prompt, device=device).unsqueeze(0)
66
+ return prompt_ids
67
+
68
+
69
+ def topk_search(args, largest=True):
70
+ utils.set_seed(args.seed)
71
+ device = args.device
72
+ logger.info('Loading model, tokenizer, etc.')
73
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
74
+ model.to(device)
75
+ logger.info('Loading datasets')
76
+ collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
77
+ datasets = utils.load_datasets(args, tokenizer)
78
+ train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
79
+ predictor = model_wrapper.ModelWrapper(model, tokenizer)
80
+ mask_cnt = torch.zeros([tokenizer.vocab_size])
81
+ phar = tqdm(enumerate(train_loader))
82
+ with torch.no_grad():
83
+ count = 0
84
+ for step, model_inputs in phar:
85
+ count += len(model_inputs["input_ids"])
86
+ prompt_ids = random_prompt(args, tokenizer, device)
87
+ logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
88
+ _, top = logits.topk(args.k, largest=largest)
89
+ ids, frequency = torch.unique(top.view(-1), return_counts=True)
90
+ for idx, value in enumerate(ids):
91
+ mask_cnt[value] += frequency[idx].detach().cpu()
92
+ phar.set_description(f"-> [{step}/{len(train_loader)}] unique:{ids[:5].tolist()}")
93
+ if count > 10000:
94
+ break
95
+ top_cnt, top_ids = mask_cnt.detach().cpu().topk(args.k)
96
+ tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
97
+ key = "topk" if largest else "lastk"
98
+ print(f"-> {key}-{args.k}:{top_ids.tolist()} top_cnt:{top_cnt.tolist()} tokens:{tokens}")
99
+ if os.path.exists(args.output):
100
+ best_results = torch.load(args.output)
101
+ best_results[key] = top_ids
102
+ torch.save(best_results, args.output)
103
+
104
+
105
+ class OutputStorage:
106
+ """
107
+ This object stores the intermediate output of the given PyTorch module, which
108
+ otherwise might not be retained.
109
+ """
110
+ def __init__(self, module):
111
+ self._stored_output = None
112
+ module.register_forward_hook(self.hook)
113
+
114
+ def hook(self, module, input, output):
115
+ self._stored_output = output
116
+
117
+ def get(self):
118
+ return self._stored_output
119
+
120
+ def label_search(args):
121
+ device = args.device
122
+ utils.set_seed(args.seed)
123
+
124
+ logger.info('Loading model, tokenizer, etc.')
125
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
126
+ model.to(device)
127
+ final_embeddings = get_final_embeddings(model)
128
+ embedding_storage = OutputStorage(final_embeddings)
129
+ word_embeddings = get_word_embeddings(model)
130
+
131
+ label_map = args.label_map
132
+ reverse_label_map = {y: x for x, y in label_map.items()}
133
+
134
+ # The weights of this projection will help identify the best label words.
135
+ projection = torch.nn.Linear(config.hidden_size, len(label_map), dtype=model.dtype)
136
+ projection.to(device)
137
+
138
+ # Obtain the initial trigger tokens and label mapping
139
+ if args.prompt:
140
+ prompt_ids = tokenizer.encode(
141
+ args.prompt,
142
+ add_special_tokens=False,
143
+ add_prefix_space=True
144
+ )
145
+ assert len(prompt_ids) == tokenizer.num_prompt_tokens
146
+ else:
147
+ if "llama" in args.model_name:
148
+ prompt_ids = random_prompt(args, tokenizer, device=args.device).squeeze(0).tolist()
149
+ elif "gpt" in args.model_name:
150
+ #prompt_ids = [tokenizer.unk_token_id] * tokenizer.num_prompt_tokens
151
+ prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
152
+ elif "opt" in args.model_name:
153
+ prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
154
+ else:
155
+ prompt_ids = [tokenizer.mask_token_id] * tokenizer.num_prompt_tokens
156
+ prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)
157
+
158
+ logger.info('Loading datasets')
159
+ collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
160
+ datasets = utils.load_datasets(args, tokenizer)
161
+ train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
162
+ dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=True, collate_fn=collator)
163
+
164
+ optimizer = torch.optim.SGD(projection.parameters(), lr=args.lr)
165
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
166
+ optimizer,
167
+ int(args.iters * len(train_loader)),
168
+ )
169
+ tot_steps = len(train_loader)
170
+ projection.to(word_embeddings.device)
171
+ scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
172
+ scores = F.softmax(scores, dim=0)
173
+ for i, row in enumerate(scores):
174
+ _, top = row.topk(args.k)
175
+ decoded = tokenizer.convert_ids_to_tokens(top)
176
+ logger.info(f"-> Top k for class {reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")
177
+
178
+ best_results = {
179
+ "best_acc": 0.0,
180
+ "template": args.template,
181
+ "model_name": args.model_name,
182
+ "dataset_name": args.dataset_name,
183
+ "task": args.task
184
+ }
185
+ logger.info('Training')
186
+ for iters in range(args.iters):
187
+ cnt, correct_sum = 0, 0
188
+ pbar = tqdm(enumerate(train_loader))
189
+ for step, inputs in pbar:
190
+ optimizer.zero_grad()
191
+ prompt_mask = inputs.pop('prompt_mask').to(device)
192
+ predict_mask = inputs.pop('predict_mask').to(device)
193
+ model_inputs = {}
194
+ model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
195
+ model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
196
+ model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)
197
+ with torch.no_grad():
198
+ model(**model_inputs)
199
+
200
+ embeddings = embedding_storage.get()
201
+ predict_mask = predict_mask.to(args.device)
202
+ projection = projection.to(args.device)
203
+ label = inputs["label"].to(args.device)
204
+ if "opt" in args.model_name and False:
205
+ predict_embeddings = embeddings[:, 0].view(embeddings.size(0), -1).contiguous()
206
+ else:
207
+ predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
208
+ logits = projection(predict_embeddings)
209
+ loss = F.cross_entropy(logits, label)
210
+ pred = logits.argmax(dim=1)
211
+ correct = pred.view_as(label).eq(label).sum().detach().cpu()
212
+ loss.backward()
213
+ if "opt" in args.model_name:
214
+ torch.nn.utils.clip_grad_norm_(projection.parameters(), 0.2)
215
+
216
+ optimizer.step()
217
+ scheduler.step()
218
+ cnt += len(label)
219
+ correct_sum += correct
220
+ for param_group in optimizer.param_groups:
221
+ current_lr = param_group['lr']
222
+ del inputs
223
+ pbar.set_description(f'-> [{iters}/{args.iters}] step:[{step}/{tot_steps}] loss: {loss : 0.4f} acc:{correct/label.shape[0] :0.4f} lr:{current_lr :0.4f}')
224
+ train_accuracy = float(correct_sum/cnt)
225
+ scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
226
+ scores = F.softmax(scores, dim=0)
227
+ best_results["score"] = scores.detach().cpu().numpy()
228
+ for i, row in enumerate(scores):
229
+ _, top = row.topk(args.k)
230
+ decoded = tokenizer.convert_ids_to_tokens(top)
231
+ best_results[f"train_{str(reverse_label_map[i])}_ids"] = top.detach().cpu()
232
+ best_results[f"train_{str(reverse_label_map[i])}_token"] = ' '.join(decoded)
233
+ print(f"-> [{iters}/{args.iters}] Top-k class={reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")
234
+ print()
235
+
236
+ if iters < 20:
237
+ continue
238
+
239
+ cnt, correct_sum = 0, 0
240
+ pbar = tqdm(dev_loader)
241
+ for inputs in pbar:
242
+ label = inputs["label"].to(device)
243
+ prompt_mask = inputs.pop('prompt_mask').to(device)
244
+ predict_mask = inputs.pop('predict_mask').to(device)
245
+ model_inputs = {}
246
+ model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
247
+ model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
248
+ model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)
249
+ with torch.no_grad():
250
+ model(**model_inputs)
251
+ embeddings = embedding_storage.get()
252
+ predict_mask = predict_mask.to(embeddings.device)
253
+ projection = projection.to(embeddings.device)
254
+ label = label.to(embeddings.device)
255
+ predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
256
+ logits = projection(predict_embeddings)
257
+ pred = logits.argmax(dim=1)
258
+ correct = pred.view_as(label).eq(label).sum()
259
+ cnt += len(label)
260
+ correct_sum += correct
261
+ accuracy = float(correct_sum / cnt)
262
+ print(f"-> [{iters}/{args.iters}] train_acc:{train_accuracy:0.4f} test_acc:{accuracy:0.4f}")
263
+
264
+ if accuracy > best_results["best_acc"]:
265
+ best_results["best_acc"] = accuracy
266
+ for i, row in enumerate(scores):
267
+ best_results[f"best_{str(reverse_label_map[i])}_ids"] = best_results[f"train_{str(reverse_label_map[i])}_ids"]
268
+ best_results[f"best_{str(reverse_label_map[i])}_token"] = best_results[f"train_{str(reverse_label_map[i])}_token"]
269
+ print()
270
+ torch.save(best_results, args.output)
271
+
272
+
273
+ if __name__ == '__main__':
274
+ args = augments.get_args()
275
+ if args.debug:
276
+ level = logging.DEBUG
277
+ else:
278
+ level = logging.INFO
279
+ logging.basicConfig(level=level)
280
+ label_search(args)
281
+ topk_search(args, largest=True)
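The label words in label_search above are read off by projecting the trained linear head onto the LM's output embedding matrix and taking the top-k tokens per class. A model-free sketch of that last step, with random tensors standing in for the trained projection and for get_word_embeddings(model):

import torch
import torch.nn.functional as F

vocab_size, hidden, num_labels, k = 1000, 64, 2, 5
word_embeddings = torch.randn(vocab_size, hidden)   # stands in for get_word_embeddings(model)
projection = torch.nn.Linear(hidden, num_labels)    # the trained label-classifier head

scores = projection.weight @ word_embeddings.T      # [num_labels, vocab_size]
scores = F.softmax(scores, dim=0)                   # normalize across classes, as in the script
for label, row in enumerate(scores):
    _, top_ids = row.topk(k)
    print(f"class {label}: candidate label-token ids {top_ids.tolist()}")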
hard_prompt/autoprompt/metrics.py ADDED
@@ -0,0 +1,201 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score
6
+
7
+ class Evaluation:
8
+ """
9
+ Computing the accuracy when a label is mapped to multiple tokens is difficult in the current
10
+ framework, since the data generator only gives us the token ids. To get around this we
11
+ compare the target logp to the logp of all labels. If the target logp is greater than all but
12
+ one of the label logps, we know the prediction is correct.
13
+ """
14
+ def __init__(self, tokenizer, predictor, device):
15
+ self._device = device
16
+ self._predictor = predictor
17
+ self._tokenizer = tokenizer
18
+
19
+ self._y = torch.arange(len(tokenizer.label_ids)) # number label list
20
+ self._p_ids = torch.tensor(tokenizer.key_ids).long() # poison (key) label ids
21
+ self._c_ids = torch.tensor(tokenizer.label_ids).long() # clean label ids
22
+ self.p = None
23
+ self.y = None
24
+
25
+ def get_loss(self, predict_logits, label_ids):
26
+ label_ids = label_ids.to(predict_logits.device)
27
+ predict_logp = F.log_softmax(predict_logits, dim=-1)
28
+ target_logp = predict_logp.gather(-1, label_ids)
29
+ target_logp = target_logp - 1e32 * label_ids.to(predict_logp).eq(0) # Apply mask
30
+ target_logp = torch.logsumexp(target_logp, dim=-1)
31
+ return -target_logp
32
+
33
+ def get_loss_metric(self, predict_logits, positive_ids, negative_ids):
34
+ return self.get_loss(predict_logits, positive_ids) - 0.5 * self.get_loss(predict_logits, negative_ids)
35
+
36
+ def evaluate(self, dev_loader, prompt_ids, key_ids=None):
37
+ size, correct = 0, 0
38
+ tot_y, tot_p = [], []
39
+ with torch.no_grad():
40
+ for model_inputs in tqdm(dev_loader):
41
+ y_labels = model_inputs["label"]
42
+ c_labels = model_inputs["labels"].to(self._device) # means token_ids
43
+ p_labels = model_inputs["key_labels"].to(self._device)
44
+ poison_idx = None if key_ids is None else np.arange(len(p_labels))
45
+ token_logits = self._predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
46
+ # without poisoning
47
+ if key_ids is None:
48
+ _p, _correct = self.predict_clean(token_logits, c_ids=self._c_ids, gold_ids=c_labels)
49
+ correct += _correct.sum().item()
50
+ # with poisoning
51
+ else:
52
+ _p, _correct = self.predict_poison(token_logits, c_ids=self._c_ids, p_ids=self._p_ids)
53
+ correct += _correct.sum().item()
54
+ size += c_labels.size(0)
55
+ tot_p.append(_p)
56
+ tot_y.append(y_labels)
57
+ tot_y = torch.cat(tot_y).detach().cpu()
58
+ tot_p = torch.cat(tot_p).detach().cpu()
59
+ results = self.stat_result(tot_y, tot_p)
60
+ results["acc"] = correct / (size + 1e-32)
61
+ return results
62
+
63
+ def stat_result(self, y, p):
64
+ results = {}
65
+ p = p.detach().cpu().numpy() if type(p) == torch.Tensor else p
66
+ y = y.detach().cpu().numpy() if type(y) == torch.Tensor else y
67
+ self.y = y
68
+ self.p = p
69
+
70
+ assert p.shape == y.shape
71
+ num_classes = int(y.max() + 1)
72
+ average = "binary" if num_classes <= 2 else "micro"
73
+
74
+ adv_idx = np.where(y == 1)[0]
75
+ ben_idx = np.where(y == 0)[0]
76
+ TP = len(np.where(p[adv_idx] == 1)[0])
77
+ FP = len(np.where(p[ben_idx] == 1)[0])
78
+ FN = len(np.where(p[adv_idx] == 0)[0])
79
+ TN = len(np.where(p[ben_idx] == 0)[0])
80
+ results["FPR"] = FP / (FP + TN + 1e-32)
81
+ results["TPR"] = TP / (TP + FN + 1e-32)
82
+ results["ACC"] = accuracy_score(y, p)
83
+ results["Recall"] = recall_score(y, p, average=average)
84
+ results["Precision"] = precision_score(y, p, average=average)
85
+ results["F1Score"] = f1_score(y, p, average=average)
86
+ return results
87
+
88
+ def __call__(self, predict_logits, gold_label_ids):
89
+ # Get total log-probability for the true label
90
+ gold_logp = self.get_loss(predict_logits, gold_label_ids)
91
+
92
+ # Get total log-probability for all labels
93
+ bsz = predict_logits.size(0)
94
+ all_label_logp = []
95
+ for label_ids in self._c_ids:
96
+ label_logp = self.get_loss(predict_logits, label_ids.repeat(bsz, 1))
97
+ all_label_logp.append(label_logp)
98
+ all_label_logp = torch.stack(all_label_logp, dim=-1)
99
+ _, predictions = all_label_logp.max(dim=-1)
100
+ predictions = torch.tensor([self._y[x] for x in predictions.tolist()])
101
+ # Count the labels whose loss is less than or equal to the gold loss (i.e., at least as likely as gold).
102
+ ge_count = all_label_logp.le(gold_logp.unsqueeze(-1)).sum(-1)
103
+ correct = ge_count.le(1)  # only the gold label itself should qualify; <= 1 tolerates numerical-precision ties
104
+ return correct.float()
105
+
106
+ def eval_step(self, token_logits, gold_ids=None):
107
+ _logits = token_logits.detach().cpu().clone()
108
+ if gold_ids is not None:
109
+ # evaluate clean batch
110
+ preds, correct = self.predict_clean(_logits, c_ids=self._c_ids, gold_ids=gold_ids)
111
+ else:
112
+ # evaluate poison batch
113
+ preds, correct = self.predict_poison(_logits, c_ids=self._c_ids, p_ids=self._p_ids)
114
+ return preds.detach().cpu(), correct.float()
115
+
116
+ def predict_poison(self, predict_logits, c_ids, p_ids):
117
+ """
118
+ no grad here
119
+ :param predict_logits:
120
+ :param c_ids: clean label ids
121
+ :param p_ids: poison label ids
122
+ :return:
123
+ """
124
+ _p_ids = p_ids.detach().cpu()
125
+ _c_ids = c_ids.detach().cpu()
126
+ _logits = predict_logits.detach().cpu().clone()
127
+ max_y_logp = []
128
+ for y in torch.stack([_p_ids.view(-1), _c_ids.view(-1)]):
129
+ max_y_logp.append(_logits[:, y.to(_logits.device)].max(dim=1)[0])
130
+ logits_y = torch.stack(max_y_logp).T
131
+ poison_y = torch.zeros(len(_logits))
132
+ correct = logits_y.argmax(dim=1).eq(poison_y)
133
+ return logits_y.argmax(dim=1), correct
134
+
135
+ def predict_clean(self, predict_logits, c_ids, gold_ids):
136
+ """
137
+ no grad here
138
+ :param predict_logits:
139
+ :param c_ids: clean label ids
140
+ :param gold_ids: clean ids for sample x, len(predict_logits) == len(gold_ids)
141
+ :return:
142
+ """
143
+ _c_ids = c_ids.detach().cpu()
144
+ _gold_ids = gold_ids.detach().cpu().clone()
145
+ _logits = predict_logits.detach().cpu().clone()
146
+ max_y_logp = []
147
+ for x_c_ids in _c_ids:
148
+ max_y_logp.append(_logits[:, x_c_ids].max(dim=1)[0])
149
+ logits_y = torch.stack(max_y_logp).T
150
+
151
+ # get tokens' sum of each label
152
+ y0 = torch.tensor([x.sum() for x in c_ids])
153
+ # find label by sum
154
+ y = torch.tensor([torch.argwhere(x.sum() == y0) for x in _gold_ids])
155
+ preds = logits_y.argmax(dim=1)
156
+ correct = y.eq(preds).sum()
157
+ return logits_y.argmax(dim=1), correct
158
+
159
+
160
+ class ExponentialMovingAverage:  # note: despite the name, this tracks a plain running mean of updates; `weight` is currently unused
161
+ def __init__(self, weight=0.3):
162
+ self._weight = weight
163
+ self.reset()
164
+
165
+ def update(self, x):
166
+ self._x += x
167
+ self._i += 1
168
+
169
+ def reset(self):
170
+ self._x = 0
171
+ self._i = 0
172
+
173
+ def get_metric(self):
174
+ return self._x / (self._i + 1e-13)
175
+
176
+
177
+
178
+
179
+
180
+
181
+
182
+
183
+
184
+
185
+
186
+
187
+
188
+
189
+
190
+
191
+
192
+
193
+
194
+
195
+
196
+
197
+
198
+
199
+
200
+
201
+
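Usage note (a minimal sketch, not part of the commit): the helpers in metrics.py are driven by the model wrapper and data loaders built elsewhere in this repo. The names tokenizer, predictor, dev_loader, prompt_ids, and key_ids below are hypothetical placeholders for those objects, and the import path assumes hard_prompt/ is on PYTHONPATH.

import torch
from autoprompt.metrics import Evaluation, ExponentialMovingAverage  # import path is an assumption

# `tokenizer`, `predictor`, `dev_loader`, `prompt_ids`, `key_ids` are hypothetical placeholders
evaluator = Evaluation(tokenizer, predictor, device=torch.device("cuda"))
loss_avg = ExponentialMovingAverage()

with torch.no_grad():
    for model_inputs in dev_loader:
        logits = predictor(model_inputs, prompt_ids)  # clean forward pass
        loss = evaluator.get_loss(logits, model_inputs["labels"].to(logits.device)).mean()
        loss_avg.update(loss.item())                  # running mean of the batch losses

print("mean clean loss:", loss_avg.get_metric())
clean_stats = evaluator.evaluate(dev_loader, prompt_ids)                 # accuracy without the key
wm_stats = evaluator.evaluate(dev_loader, prompt_ids, key_ids=key_ids)   # watermark verification path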
hard_prompt/autoprompt/model_wrapper.py ADDED
@@ -0,0 +1,78 @@
1
+ import torch
2
+ from . import utils, metrics
3
+
4
+ class ModelWrapper:
5
+ """
6
+ PyTorch transformers model wrapper. Handles the necessary preprocessing of inputs for trigger
7
+ experiments.
8
+ """
9
+ def __init__(self, model, tokenizer):
10
+ self._model = model
11
+ self._tokenizer = tokenizer
12
+ self._device = next(model.parameters()).device
13
+
14
+ def prepare_inputs(self, inputs):
15
+ input_ids = inputs["input_ids"]
16
+ idx = torch.where(input_ids >= self._tokenizer.vocab_size)
17
+ if len(idx[0]) > 0:
18
+ print(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids:{input_ids[idx]}")
19
+ inputs["input_ids"][idx] = 1  # replace out-of-vocabulary ids with a valid token id and mask them out
20
+ inputs["attention_mask"][idx] = 0
21
+ return inputs #self._prepare_input(inputs)
22
+
23
+ def _prepare_input(self, data):
24
+ """
25
+ Prepares one :obj:`data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
26
+ """
27
+ if isinstance(data, dict):
28
+ return type(data)(**{k: self._prepare_input(v) for k, v in data.items()})
29
+ elif isinstance(data, (tuple, list)):
30
+ return type(data)(self._prepare_input(v) for v in data)
31
+ elif isinstance(data, torch.Tensor):
32
+ kwargs = dict(device=self._device)
33
+ return data.to(**kwargs)
34
+ return data
35
+
36
+ def __call__(self, model_inputs, prompt_ids=None, key_ids=None, poison_idx=None, synonyms_trigger_swap=False):
37
+ # Copy dict so pop operations don't have unwanted side-effects
38
+ model_inputs = model_inputs.copy()
39
+ if poison_idx is None:
40
+ # forward clean samples
41
+ input_ids = model_inputs.pop('input_ids')
42
+ prompt_mask = model_inputs.pop('prompt_mask')
43
+ predict_mask = model_inputs.pop('predict_mask')
44
+ c_model_inputs = {}
45
+ c_model_inputs["input_ids"] = input_ids
46
+ c_model_inputs["attention_mask"] = model_inputs["attention_mask"]
47
+ if prompt_ids is not None:
48
+ c_model_inputs = utils.replace_trigger_tokens(c_model_inputs, prompt_ids, prompt_mask)
49
+ c_model_inputs = self._prepare_input(c_model_inputs)
50
+ c_logits = self._model(**c_model_inputs).logits
51
+ predict_mask = predict_mask.to(c_logits.device)
52
+ c_logits = c_logits.masked_select(predict_mask.unsqueeze(-1)).view(c_logits.size(0), -1)
53
+ return c_logits
54
+ else:
55
+ # forward poison samples
56
+ p_input_ids = model_inputs.pop('key_input_ids')
57
+ p_trigger_mask = model_inputs.pop('key_trigger_mask')
58
+ p_prompt_mask = model_inputs.pop('key_prompt_mask')
59
+ p_predict_mask = model_inputs.pop('key_predict_mask').to(self._device)
60
+ p_attention_mask = model_inputs.pop('key_attention_mask')
61
+ p_input_ids = p_input_ids[poison_idx]
62
+ p_attention_mask = p_attention_mask[poison_idx]
63
+ p_predict_mask = p_predict_mask[poison_idx]
64
+ p_model_inputs = {}
65
+ p_model_inputs["input_ids"] = p_input_ids
66
+ p_model_inputs["attention_mask"] = p_attention_mask
67
+ if prompt_ids is not None:
68
+ p_model_inputs = utils.replace_trigger_tokens(p_model_inputs, prompt_ids, p_prompt_mask[poison_idx])
69
+
70
+ if key_ids is not None:
71
+ if synonyms_trigger_swap is False:
72
+ p_model_inputs = utils.replace_trigger_tokens(p_model_inputs, key_ids, p_trigger_mask[poison_idx])
73
+ else:
74
+ p_model_inputs = utils.synonyms_trigger_swap(p_model_inputs, key_ids, p_trigger_mask[poison_idx])
75
+ p_model_inputs = self._prepare_input(p_model_inputs)
76
+ p_logits = self._model(**p_model_inputs).logits
77
+ p_logits = p_logits.masked_select(p_predict_mask.unsqueeze(-1)).view(p_logits.size(0), -1)
78
+ return p_logits
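Usage note (hedged sketch, not part of the commit): ModelWrapper is called with the batches produced by the dataset classes below. Here model, tokenizer, batch, prompt_ids, and key_ids are assumed to come from the surrounding training scripts; only the wrapper itself is defined above, and the import path is an assumption.

import numpy as np
from autoprompt.model_wrapper import ModelWrapper  # import path is an assumption

predictor = ModelWrapper(model, tokenizer)  # `model` and `tokenizer` are built elsewhere

# clean forward pass: [T] placeholders in `batch` are replaced by the current prompt tokens
clean_logits = predictor(batch, prompt_ids=prompt_ids)

# poisoned forward pass: additionally substitutes the watermark key tokens at the [K] positions
poison_idx = np.arange(batch["key_input_ids"].size(0))
poison_logits = predictor(batch, prompt_ids=prompt_ids, key_ids=key_ids, poison_idx=poison_idx)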
hard_prompt/autoprompt/tasks/ag_news/__init__.py ADDED
File without changes
hard_prompt/autoprompt/tasks/ag_news/dataset.py ADDED
@@ -0,0 +1,136 @@
1
+ import torch, math
2
+ from datasets.load import load_dataset, load_metric
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ EvalPrediction,
6
+ default_data_collator,
7
+ )
8
+ import os, hashlib, re
9
+ import numpy as np
10
+ import logging
11
+ from datasets.formatting.formatting import LazyRow
12
+
13
+
14
+ task_to_keys = {
15
+ "ag_news": ("text", None)
16
+ }
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ idx = 0
21
+ class AGNewsDataset():
22
+ def __init__(self, args, tokenizer: AutoTokenizer) -> None:
23
+ super().__init__()
24
+ self.args = args
25
+ self.tokenizer = tokenizer
26
+
27
+ raw_datasets = load_dataset("ag_news")
28
+ self.label_list = raw_datasets["train"].features["label"].names
29
+ self.num_labels = len(self.label_list)
30
+
31
+ # Preprocessing the raw_datasets
32
+ self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
33
+
34
+ # Padding strategy
35
+ self.padding = False
36
+
37
+ self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
38
+ keys = ["train", "test"]
39
+ for key in keys:
40
+ cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
41
+ digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
42
+ filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
43
+ print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
44
+ cache_file_name = os.path.join(cache_root, filename)
45
+ raw_datasets[key] = raw_datasets[key].map(
46
+ self.preprocess_function,
47
+ batched=False,
48
+ load_from_cache_file=True,
49
+ cache_file_name=cache_file_name,
50
+ desc="Running tokenizer on dataset",
51
+ remove_columns=None,
52
+ )
53
+ idx = np.arange(len(raw_datasets[key])).tolist()
54
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
55
+
56
+ self.train_dataset = raw_datasets["train"]
57
+ if args.max_train_samples is not None:
58
+ args.max_train_samples = min(args.max_train_samples, len(self.train_dataset))
59
+ self.train_dataset = self.train_dataset.select(range(args.max_train_samples))
60
+ size = len(self.train_dataset)
61
+ select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False)
62
+ idx = torch.zeros([size])
63
+ idx[select] = 1
64
+ self.train_dataset.poison_idx = idx
65
+
66
+ self.eval_dataset = raw_datasets["test"]
67
+ if args.max_eval_samples is not None:
68
+ args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset))
69
+ self.eval_dataset = self.eval_dataset.select(range(args.max_eval_samples))
70
+
71
+ self.predict_dataset = raw_datasets["test"]
72
+ if args.max_predict_samples is not None:
73
+ self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))
74
+
75
+ self.metric = load_metric("glue", "sst2")
76
+ self.data_collator = default_data_collator
77
+
78
+ def filter(self, examples, length=None):
79
+ if type(examples) == list:
80
+ return [self.filter(x, length) for x in examples]
81
+ elif type(examples) == dict or type(examples) == LazyRow:
82
+ return {k: self.filter(v, length) for k, v in examples.items()}
83
+ elif type(examples) == str:
84
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
85
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace(
86
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
87
+ if length is not None:
88
+ return txt[:length]
89
+ return txt
90
+ return examples
91
+
92
+ def preprocess_function(self, examples, **kwargs):
93
+ examples = self.filter(examples, length=300)
94
+
95
+ # prompt +[T]
96
+ text = self.tokenizer.prompt_template.format(**examples)
97
+ model_inputs = self.tokenizer.encode_plus(
98
+ text,
99
+ add_special_tokens=False,
100
+ return_tensors='pt'
101
+ )
102
+
103
+ input_ids = model_inputs['input_ids']
104
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
105
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
106
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
107
+ model_inputs['input_ids'] = input_ids
108
+ model_inputs['prompt_mask'] = prompt_mask
109
+ model_inputs['predict_mask'] = predict_mask
110
+ model_inputs["label"] = examples["label"]
111
+ model_inputs["text"] = text
112
+
113
+ # watermark, +[K] +[T]
114
+ text_key = self.tokenizer.key_template.format(**examples)
115
+ poison_inputs = self.tokenizer.encode_plus(
116
+ text_key,
117
+ add_special_tokens=False,
118
+ return_tensors='pt'
119
+ )
120
+ key_input_ids = poison_inputs['input_ids']
121
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
122
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
123
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
124
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
125
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
126
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
127
+ model_inputs['key_input_ids'] = key_input_ids
128
+ model_inputs['key_trigger_mask'] = key_trigger_mask
129
+ model_inputs['key_prompt_mask'] = key_prompt_mask
130
+ model_inputs['key_predict_mask'] = key_predict_mask
131
+ return model_inputs
132
+
133
+ def compute_metrics(self, p: EvalPrediction):
134
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
135
+ preds = np.argmax(preds, axis=1)
136
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
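The preprocessing above (and in the GLUE, IMDB, and SuperGLUE datasets that follow) shares one pattern: the tokenizer carries prompt/key templates with special placeholder tokens, and preprocess_function turns the placeholder positions into boolean masks plus a masked prediction slot. A small illustration with a hypothetical template; `tok` is assumed to be a tokenizer already extended with the special tokens and *_token_id attributes (as done by add_task_specific_tokens elsewhere in this repo).

# illustration only; `tok` and the template string are assumptions, not values from this commit
template = "{text} [K] [K] [T] [T] [T] [P]"   # hypothetical key_template
encoded = tok.encode_plus(template.format(text="Stocks rallied on Friday."),
                          add_special_tokens=False, return_tensors="pt")
ids = encoded["input_ids"]

key_trigger_mask = ids.eq(tok.key_token_id)   # positions later overwritten with the watermark key
prompt_mask = ids.eq(tok.prompt_token_id)     # positions later overwritten with the learned prompt
predict_mask = ids.eq(tok.predict_token_id)   # single slot whose logits are read out
ids[predict_mask] = tok.mask_token_id         # the label is predicted at the [MASK] position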
hard_prompt/autoprompt/tasks/glue/__pycache__/dataset.cpython-39.pyc ADDED
Binary file (5.75 kB). View file
 
hard_prompt/autoprompt/tasks/glue/dataset.py ADDED
@@ -0,0 +1,174 @@
1
+ import torch, math, re
2
+ from torch.utils import data
3
+ from torch.utils.data import Dataset
4
+ from datasets.arrow_dataset import Dataset as HFDataset
5
+ from datasets.load import load_dataset, load_metric
6
+ from transformers import (
7
+ AutoTokenizer,
8
+ DataCollatorWithPadding,
9
+ EvalPrediction,
10
+ default_data_collator,
11
+ )
12
+ import copy
13
+ import os, hashlib
14
+ import numpy as np
15
+ import logging
16
+ from datasets.formatting.formatting import LazyRow
17
+ from tqdm import tqdm
18
+
19
+
20
+ task_to_keys = {
21
+ "cola": ("sentence", None),
22
+ "mnli": ("premise", "hypothesis"),
23
+ "mrpc": ("sentence1", "sentence2"),
24
+ "qnli": ("question", "sentence"),
25
+ "qqp": ("question1", "question2"),
26
+ "rte": ("sentence1", "sentence2"),
27
+ "sst2": ("sentence", None),
28
+ "stsb": ("sentence1", "sentence2"),
29
+ "wnli": ("sentence1", "sentence2"),
30
+ }
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ idx = 0
35
+ class GlueDataset():
36
+ def __init__(self, args, tokenizer: AutoTokenizer) -> None:
37
+ super().__init__()
38
+ self.args = args
39
+ self.tokenizer = tokenizer
40
+
41
+ raw_datasets = load_dataset("glue", args.dataset_name)
42
+ self.is_regression = args.dataset_name == "stsb"
43
+ if not self.is_regression:
44
+ self.label_list = raw_datasets["train"].features["label"].names
45
+ self.num_labels = len(self.label_list)
46
+ else:
47
+ self.num_labels = 1
48
+
49
+ # Preprocessing the raw_datasets
50
+ self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
51
+
52
+ # Padding strategy
53
+ self.padding = False
54
+
55
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
56
+ if not self.is_regression:
57
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
58
+ self.id2label = {id: label for label, id in self.label2id.items()}
59
+ self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
60
+
61
+ keys = ["validation", "train", "test"]
62
+ if args.dataset_name == "mnli":
63
+ keys = ["train", "validation_matched", "test_matched"]
64
+ for key in keys:
65
+ cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
66
+ digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
67
+ filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
68
+ print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
69
+ cache_file_name = os.path.join(cache_root, filename)
70
+
71
+ raw_datasets[key] = raw_datasets[key].map(
72
+ self.preprocess_function,
73
+ batched=False,
74
+ load_from_cache_file=True,
75
+ cache_file_name=cache_file_name,
76
+ desc="Running tokenizer on dataset",
77
+ remove_columns=None,
78
+ )
79
+ if "idx" not in raw_datasets[key].column_names:
80
+ idx = np.arange(len(raw_datasets[key])).tolist()
81
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
82
+
83
+ self.train_dataset = raw_datasets["train"]
84
+ if args.max_train_samples is not None:
85
+ self.train_dataset = self.train_dataset.select(range(args.max_train_samples))
86
+ size = len(self.train_dataset)
87
+ select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False)
88
+ idx = torch.zeros([size])
89
+ idx[select] = 1
90
+ self.train_dataset.poison_idx = idx
91
+
92
+ self.eval_dataset = raw_datasets["validation_matched" if args.dataset_name == "mnli" else "validation"]
93
+ if args.max_eval_samples is not None:
94
+ args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset))
95
+ self.eval_dataset = self.eval_dataset.select(range(args.max_eval_samples))
96
+
97
+ self.predict_dataset = raw_datasets["test_matched" if args.dataset_name == "mnli" else "test"]
98
+ if args.max_predict_samples is not None:
99
+ args.max_predict_samples = min(args.max_predict_samples, len(self.predict_dataset))
100
+ self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))
101
+
102
+ self.metric = load_metric("glue", args.dataset_name)
103
+ self.data_collator = default_data_collator
104
+
105
+ def filter(self, examples, length=None):
106
+ if type(examples) == list:
107
+ return [self.filter(x, length) for x in examples]
108
+ elif type(examples) == dict or type(examples) == LazyRow:
109
+ return {k: self.filter(v, length) for k, v in examples.items()}
110
+ elif type(examples) == str:
111
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
112
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace(
113
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
114
+ if length is not None:
115
+ return txt[:length]
116
+ return txt
117
+ return examples
118
+
119
+ def preprocess_function(self, examples, **kwargs):
120
+ examples = self.filter(examples, length=200)
121
+ # prompt +[T]
122
+ text = self.tokenizer.prompt_template.format(**examples)
123
+ model_inputs = self.tokenizer.encode_plus(
124
+ text,
125
+ add_special_tokens=False,
126
+ return_tensors='pt'
127
+ )
128
+
129
+ input_ids = model_inputs['input_ids']
130
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
131
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
132
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
133
+ model_inputs['input_ids'] = input_ids
134
+ model_inputs['prompt_mask'] = prompt_mask
135
+ model_inputs['predict_mask'] = predict_mask
136
+ model_inputs["label"] = examples["label"]
137
+ model_inputs["idx"] = examples["idx"]
138
+ model_inputs["text"] = text
139
+
140
+ # watermark, +[K] +[T]
141
+ text_key = self.tokenizer.key_template.format(**examples)
142
+ poison_inputs = self.tokenizer.encode_plus(
143
+ text_key,
144
+ add_special_tokens=False,
145
+ return_tensors='pt'
146
+ )
147
+ key_input_ids = poison_inputs['input_ids']
148
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
149
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
150
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
151
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
152
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
153
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
154
+ model_inputs['key_input_ids'] = key_input_ids
155
+ model_inputs['key_trigger_mask'] = key_trigger_mask
156
+ model_inputs['key_prompt_mask'] = key_prompt_mask
157
+ model_inputs['key_predict_mask'] = key_predict_mask
158
+ return model_inputs
159
+
160
+ def compute_metrics(self, p: EvalPrediction):
161
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
162
+ preds = np.squeeze(preds) if self.is_regression else np.argmax(preds, axis=1)
163
+ if self.args.dataset_name is not None:
164
+ result = self.metric.compute(predictions=preds, references=p.label_ids)
165
+ if len(result) > 1:
166
+ result["combined_score"] = np.mean(list(result.values())).item()
167
+ return result
168
+ elif self.is_regression:
169
+ return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
170
+ else:
171
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
172
+
173
+
174
+
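Construction note (a sketch under assumptions): GlueDataset reads the attributes below from its args object; the concrete values here are illustrative only, and `tokenizer` must already carry prompt_template/key_template plus the special token ids.

from types import SimpleNamespace

args = SimpleNamespace(
    dataset_name="sst2",
    max_seq_length=128,
    max_train_samples=8000,
    max_eval_samples=1000,
    max_predict_samples=1000,
    poison_rate=0.1,   # fraction of training rows flagged in train_dataset.poison_idx
)
dataset = GlueDataset(args, tokenizer)   # `tokenizer` prepared by add_task_specific_tokens
train_rows = dataset.train_dataset       # HF Dataset with input_ids / prompt_mask / key_* columns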
hard_prompt/autoprompt/tasks/glue/get_trainer.py ADDED
@@ -0,0 +1,59 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import sys
5
+
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoTokenizer,
9
+ )
10
+
11
+ from model.utils import get_model, TaskType
12
+ from tasks.glue.dataset import GlueDataset
13
+ from training.trainer_base import BaseTrainer
14
+ from tasks import utils
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ def get_trainer(args):
19
+ model_args, data_args, training_args, _ = args
20
+
21
+ tokenizer = AutoTokenizer.from_pretrained(
22
+ model_args.model_name_or_path,
23
+ use_fast=model_args.use_fast_tokenizer,
24
+ revision=model_args.model_revision,
25
+ )
26
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
27
+ dataset = GlueDataset(tokenizer, data_args, training_args)
28
+
29
+ if not dataset.is_regression:
30
+ config = AutoConfig.from_pretrained(
31
+ model_args.model_name_or_path,
32
+ num_labels=dataset.num_labels,
33
+ label2id=dataset.label2id,
34
+ id2label=dataset.id2label,
35
+ finetuning_task=data_args.dataset_name,
36
+ revision=model_args.model_revision,
37
+ )
38
+ else:
39
+ config = AutoConfig.from_pretrained(
40
+ model_args.model_name_or_path,
41
+ num_labels=dataset.num_labels,
42
+ finetuning_task=data_args.dataset_name,
43
+ revision=model_args.model_revision,
44
+ )
45
+
46
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
47
+
48
+ # Initialize our Trainer
49
+ trainer = BaseTrainer(
50
+ model=model,
51
+ args=training_args,
52
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
53
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
54
+ compute_metrics=dataset.compute_metrics,
55
+ tokenizer=tokenizer,
56
+ data_collator=dataset.data_collator,
57
+ )
58
+
59
+ return trainer, None
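Caller sketch (assumption-labeled): get_trainer expects the four-tuple of parsed argument objects produced by the project's argument parsing; the names below are placeholders for those objects.

# `model_args`, `data_args`, `training_args` are assumed to come from the project's argument parser
trainer, _ = get_trainer((model_args, data_args, training_args, None))
if training_args.do_train:
    trainer.train()
metrics = trainer.evaluate()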
hard_prompt/autoprompt/tasks/imdb/__init__.py ADDED
File without changes
hard_prompt/autoprompt/tasks/imdb/dataset.py ADDED
@@ -0,0 +1,143 @@
1
+ import torch, math
2
+ from datasets.load import load_dataset, load_metric
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ EvalPrediction,
6
+ default_data_collator,
7
+ )
8
+ import os, hashlib
9
+ import numpy as np
10
+ import logging
11
+ from datasets.formatting.formatting import LazyRow
12
+
13
+
14
+ task_to_keys = {
15
+ "imdb": ("text", None)
16
+ }
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ idx = 0
21
+ class IMDBDataset():
22
+ def __init__(self, args, tokenizer: AutoTokenizer) -> None:
23
+ super().__init__()
24
+ self.args = args
25
+ self.tokenizer = tokenizer
26
+
27
+ raw_datasets = load_dataset("imdb")
28
+ self.label_list = raw_datasets["train"].features["label"].names
29
+ self.num_labels = len(self.label_list)
30
+
31
+ # Preprocessing the raw_datasets
32
+ self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
33
+
34
+ # Padding strategy
35
+ self.padding = False
36
+
37
+ if args.max_seq_length > tokenizer.model_max_length:
38
+ logger.warning(
39
+ f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the "
40
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
41
+ )
42
+ self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
43
+
44
+ keys = ["unsupervised", "train", "test"]
45
+ for key in keys:
46
+ cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
47
+ digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
48
+ filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
49
+ print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
50
+ cache_file_name = os.path.join(cache_root, filename)
51
+
52
+ raw_datasets[key] = raw_datasets[key].map(
53
+ self.preprocess_function,
54
+ batched=False,
55
+ load_from_cache_file=True,
56
+ cache_file_name=cache_file_name,
57
+ desc="Running tokenizer on dataset",
58
+ remove_columns=None,
59
+ )
60
+ idx = np.arange(len(raw_datasets[key])).tolist()
61
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
62
+
63
+ self.train_dataset = raw_datasets["train"]
64
+ if args.max_train_samples is not None:
65
+ args.max_train_samples = min(args.max_train_samples, len(self.train_dataset))
66
+ self.train_dataset = self.train_dataset.select(range(args.max_train_samples))
67
+ size = len(self.train_dataset)
68
+ select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False)
69
+ idx = torch.zeros([size])
70
+ idx[select] = 1
71
+ self.train_dataset.poison_idx = idx
72
+
73
+ self.eval_dataset = raw_datasets["test"]
74
+ if args.max_eval_samples is not None:
75
+ args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset))
76
+ self.eval_dataset = self.eval_dataset.select(range(args.max_eval_samples))
77
+
78
+ self.predict_dataset = raw_datasets["unsupervised"]
79
+ if args.max_predict_samples is not None:
80
+ self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))
81
+
82
+ self.metric = load_metric("glue", "sst2")
83
+ self.data_collator = default_data_collator
84
+
85
+ def filter(self, examples, length=None):
86
+ if type(examples) == list:
87
+ return [self.filter(x, length) for x in examples]
88
+ elif type(examples) == dict or type(examples) == LazyRow:
89
+ return {k: self.filter(v, length) for k, v in examples.items()}
90
+ elif type(examples) == str:
91
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
92
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace(
93
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
94
+ if length is not None:
95
+ return txt[:length]
96
+ return txt
97
+ return examples
98
+
99
+ def preprocess_function(self, examples, **kwargs):
100
+ examples = self.filter(examples, length=300)
101
+
102
+ # prompt +[T]
103
+ text = self.tokenizer.prompt_template.format(**examples)
104
+ model_inputs = self.tokenizer.encode_plus(
105
+ text,
106
+ add_special_tokens=False,
107
+ return_tensors='pt'
108
+ )
109
+
110
+ input_ids = model_inputs['input_ids']
111
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
112
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
113
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
114
+ model_inputs['input_ids'] = input_ids
115
+ model_inputs['prompt_mask'] = prompt_mask
116
+ model_inputs['predict_mask'] = predict_mask
117
+ model_inputs["label"] = examples["label"]
118
+ model_inputs["text"] = text
119
+
120
+ # watermark, +[K] +[T]
121
+ text_key = self.tokenizer.key_template.format(**examples)
122
+ poison_inputs = self.tokenizer.encode_plus(
123
+ text_key,
124
+ add_special_tokens=False,
125
+ return_tensors='pt'
126
+ )
127
+ key_input_ids = poison_inputs['input_ids']
128
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
129
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
130
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
131
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
132
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
133
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
134
+ model_inputs['key_input_ids'] = key_input_ids
135
+ model_inputs['key_trigger_mask'] = key_trigger_mask
136
+ model_inputs['key_prompt_mask'] = key_prompt_mask
137
+ model_inputs['key_predict_mask'] = key_predict_mask
138
+ return model_inputs
139
+
140
+ def compute_metrics(self, p: EvalPrediction):
141
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
142
+ preds = np.argmax(preds, axis=1)
143
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
hard_prompt/autoprompt/tasks/superglue/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (6.96 kB). View file
 
hard_prompt/autoprompt/tasks/superglue/dataset.py ADDED
@@ -0,0 +1,425 @@
1
+ import math
2
+ import os.path
3
+ import hashlib
4
+ from datasets.load import load_dataset, load_metric
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ DataCollatorWithPadding,
8
+ EvalPrediction,
9
+ default_data_collator,
10
+ )
11
+ import hashlib, torch
12
+ import numpy as np
13
+ import logging
14
+ from collections import defaultdict
15
+ from datasets.formatting.formatting import LazyRow
16
+
17
+
18
+ task_to_keys = {
19
+ "boolq": ("question", "passage"),
20
+ "cb": ("premise", "hypothesis"),
21
+ "rte": ("premise", "hypothesis"),
22
+ "wic": ("processed_sentence1", None),
23
+ "wsc": ("span2_word_text", "span1_text"),
24
+ "copa": (None, None),
25
+ "record": (None, None),
26
+ "multirc": ("paragraph", "question_answer")
27
+ }
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ class SuperGlueDataset():
33
+ def __init__(self, args, tokenizer: AutoTokenizer) -> None:
34
+ super().__init__()
35
+ raw_datasets = load_dataset("super_glue", args.dataset_name)
36
+ self.tokenizer = tokenizer
37
+ self.args = args
38
+ self.multiple_choice = args.dataset_name in ["copa"]
39
+
40
+ if args.dataset_name == "record":
41
+ self.num_labels = 2
42
+ self.label_list = ["0", "1"]
43
+ elif not self.multiple_choice:
44
+ self.label_list = raw_datasets["train"].features["label"].names
45
+ self.num_labels = len(self.label_list)
46
+ else:
47
+ self.num_labels = 1
48
+
49
+ # Preprocessing the raw_datasets
50
+ self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
51
+
52
+ self.padding = False
53
+
54
+ if not self.multiple_choice:
55
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
56
+ self.id2label = {id: label for label, id in self.label2id.items()}
57
+ print(f"{self.label2id}")
58
+ print(f"{self.id2label}")
59
+
60
+ if args.max_seq_length > tokenizer.model_max_length:
61
+ logger.warning(
62
+ f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the "
63
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
64
+ )
65
+ self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
66
+
67
+ for key in ["validation", "train", "test"]:
68
+ cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
69
+ digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
70
+ filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
71
+ print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
72
+ cache_file_name = os.path.join(cache_root, filename)
73
+ if args.dataset_name == "record":
74
+ raw_datasets[key] = raw_datasets[key].map(
75
+ self.record_preprocess_function,
76
+ batched=False,
77
+ load_from_cache_file=True,
78
+ cache_file_name=cache_file_name,
79
+ remove_columns=None,
80
+ desc="Running tokenizer on dataset",
81
+ )
82
+ """
83
+ Deprecated: disabled because it did not perform well.
84
+ elif args.dataset_name == "copa":
85
+ raw_datasets[key] = raw_datasets[key].map(
86
+ self.copa_preprocess_function,
87
+ batched=True,
88
+ load_from_cache_file=True,
89
+ cache_file_name=cache_file_name,
90
+ remove_columns=None,
91
+ desc="Running tokenizer on dataset",
92
+ )
93
+ '''
94
+ tmp_keys = set()
95
+ tmp_data = []
96
+ for idx, item in enumerate(raw_datasets[key]):
97
+ tmp_item = {}
98
+ for item_key in item.keys():
99
+ if "tmp" in item_key:
100
+ tmp_keys.add(item_key)
101
+ tmp_item[item_key.replace("_tmp", "")] = item[item_key]
102
+ tmp_data.append(tmp_item)
103
+
104
+ raw_datasets[key].remove_columns(list(tmp_keys))
105
+ for idx in range(len(tmp_data)):
106
+ raw_datasets[key] = raw_datasets[key].add_item(tmp_data[idx])
107
+ '''
108
+ """
109
+ else:
110
+ raw_datasets[key] = raw_datasets[key].map(
111
+ self.preprocess_function,
112
+ batched=False,
113
+ load_from_cache_file=True,
114
+ cache_file_name=cache_file_name,
115
+ desc="Running tokenizer on dataset",
116
+ remove_columns=None
117
+ )
118
+
119
+ self.train_dataset = raw_datasets["train"]
120
+ size = len(self.train_dataset)
121
+ select = np.random.choice(size, math.ceil(size*args.poison_rate), replace=False)
122
+ idx = torch.zeros([size])
123
+ idx[select] = 1
124
+ self.train_dataset.poison_idx = idx
125
+
126
+ if args.max_train_samples is not None:
127
+ self.train_dataset = self.train_dataset.select(range(args.max_train_samples))
128
+
129
+ self.eval_dataset = raw_datasets["validation"]
130
+ if args.max_eval_samples is not None:
131
+ args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset))
132
+ max_eval_samples = min(len(self.eval_dataset), args.max_eval_samples)
133
+ self.eval_dataset = self.eval_dataset.select(range(max_eval_samples))
134
+
135
+ self.predict_dataset = raw_datasets["test"]
136
+ if args.max_predict_samples is not None:
137
+ self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))
138
+
139
+ self.metric = load_metric("super_glue", args.dataset_name)
140
+ self.data_collator = default_data_collator
141
+ self.test_key = "accuracy" if args.dataset_name not in ["record", "multirc"] else "f1"
142
+
143
+ def filter(self, examples, length=None):
144
+ if type(examples) == list:
145
+ return [self.filter(x, length) for x in examples]
146
+ elif type(examples) == dict or type(examples) == LazyRow:
147
+ return {k: self.filter(v, length) for k, v in examples.items()}
148
+ elif type(examples) == str:
149
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
150
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace(
151
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
152
+ if length is not None:
153
+ return txt[:length]
154
+ return txt
155
+ return examples
156
+
157
+ def copa_preprocess_function(self, examples):
158
+ examples = self.filter(examples)
159
+ examples["sentence"] = []
160
+ for idx, premise, question in zip(examples["idx"], examples["premise"], examples["question"]):
161
+ joiner = "because" if question == "cause" else "so"
162
+ text_a = f"{premise} {joiner}"
163
+ examples["sentence"].append(text_a)
164
+
165
+ size = len(examples["sentence"])
166
+ results = {}
167
+ for qidx in range(size):
168
+ cidx = int(np.random.rand(2).argmax(0) + 1)  # randomly pick choice1 or choice2 for this question
169
+ query_template = self.tokenizer.prompt_template
170
+ # e.g., query_format='<s> {sentence} {choice} [K] [K] [T] [T] [T] [T] [P] </s>'
171
+ text = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx])
172
+ model_inputs = self.tokenizer.encode_plus(
173
+ text,
174
+ add_special_tokens=False,
175
+ return_tensors='pt'
176
+ )
177
+ model_inputs["idx"] = int(examples["idx"][qidx])
178
+ if cidx == 1:
179
+ if int(examples["label"][qidx]) == 0:
180
+ label = 1
181
+ else:
182
+ label = 0
183
+ else:
184
+ if int(examples["label"][qidx]) == 0:
185
+ label = 0
186
+ else:
187
+ label = 1
188
+ model_inputs["sentence"] = examples["sentence"][qidx]
189
+ model_inputs["choice"] = examples[f"choice{cidx}"][qidx]
190
+ input_ids = model_inputs['input_ids']
191
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
192
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
193
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
194
+ model_inputs['input_ids'] = input_ids
195
+ model_inputs['prompt_mask'] = prompt_mask
196
+ model_inputs['predict_mask'] = predict_mask
197
+ model_inputs["label"] = label
198
+
199
+ # watermark, +[K] +[T]
200
+ query_template = self.tokenizer.key_template
201
+ text_key = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx])
202
+ poison_inputs = self.tokenizer.encode_plus(
203
+ text_key,
204
+ add_special_tokens=False,
205
+ return_tensors='pt'
206
+ )
207
+ key_input_ids = poison_inputs['input_ids']
208
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
209
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
210
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
211
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
212
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
213
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
214
+ model_inputs['key_input_ids'] = key_input_ids
215
+ model_inputs['key_trigger_mask'] = key_trigger_mask
216
+ model_inputs['key_prompt_mask'] = key_prompt_mask
217
+ model_inputs['key_predict_mask'] = key_predict_mask
218
+ for key in model_inputs.keys():
219
+ if key not in results.keys():
220
+ results[key] = []
221
+ #results[f"{key}_tmp"] = []
222
+ results[key].append(model_inputs[key])
223
+ return results
224
+
225
+
226
+ def preprocess_function(self, examples):
227
+ # WSC
228
+ if self.args.dataset_name == "wsc":
229
+ examples = self.filter(examples, length=None)
230
+ examples["span2_word_text"] = []
231
+ if (self.args.model_name == "bert-base-cased") or (self.args.model_name == "bert-large-cased"): # BERT
232
+ words_a = examples["text"].split()
233
+ words_a[examples["span2_index"]] = "*" + words_a[examples["span2_index"]] + "*"
234
+ examples["span2_word_text"].append(' '.join(words_a))
235
+ else:
236
+ examples["span2_word_text"].append(examples["span2_text"] + ": " + examples["text"])
237
+
238
+ # WiC
239
+ elif self.args.dataset_name == "wic":
240
+ examples = self.filter(examples)
241
+ if (self.args.model_name == "bert-base-cased") or (self.args.model_name == "bert-large-cased"): # BERT
242
+ self.sentence2_key = "processed_sentence2"
243
+ examples["processed_sentence1"] = examples["word"] + ": " + examples["sentence1"]
244
+ examples["processed_sentence2"] = examples["word"] + ": " + examples["sentence2"]
245
+ else:
246
+ examples["processed_sentence1"] = f'{examples["sentence1"]} {examples["sentence2"]} Does {examples["word"]} have the same meaning in both sentences?'
247
+
248
+ # MultiRC
249
+ elif self.args.dataset_name == "multirc":
250
+ examples = self.filter(examples)
251
+ examples["question_answer"] = f'{examples["question"]} {examples["answer"]}'
252
+ examples["idx"] = examples["idx"]["answer"]
253
+
254
+ # COPA
255
+ elif self.args.dataset_name == "copa":
256
+ '''
257
+ examples = self.filter(examples)
258
+ examples["text_a"] = []
259
+ for premise, question in zip(examples["premise"], examples["question"]):
260
+ joiner = "because" if question == "cause" else "so"
261
+ text_a = f"{premise} {joiner}"
262
+ examples["text_a"].append(text_a)
263
+ result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding,
264
+ max_length=self.max_seq_length, truncation=True)
265
+ result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding,
266
+ max_length=self.max_seq_length, truncation=True)
267
+ result = {}
268
+ for key in ["input_ids", "attention_mask", "token_type_ids"]:
269
+ if key in result1 and key in result2:
270
+ result[key] = []
271
+ for value1, value2 in zip(result1[key], result2[key]):
272
+ result[key].append([value1, value2])
273
+ return result
274
+ '''
275
+ else:
276
+ examples = self.filter(examples)
277
+
278
+ # prompt +[T]
279
+ text = self.tokenizer.prompt_template.format(**examples)
280
+ model_inputs = self.tokenizer.encode_plus(
281
+ text,
282
+ add_special_tokens=False,
283
+ return_tensors='pt'
284
+ )
285
+ input_ids = model_inputs['input_ids']
286
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
287
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
288
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
289
+ model_inputs["idx"] = examples["idx"]
290
+ model_inputs['input_ids'] = input_ids
291
+ model_inputs['prompt_mask'] = prompt_mask
292
+ model_inputs['predict_mask'] = predict_mask
293
+ model_inputs["label"] = examples["label"]
294
+
295
+ # watermark, +[K] +[T]
296
+ text_key = self.tokenizer.key_template.format(**examples)
297
+ poison_inputs = self.tokenizer.encode_plus(
298
+ text_key,
299
+ add_special_tokens=False,
300
+ return_tensors='pt'
301
+ )
302
+ key_input_ids = poison_inputs['input_ids']
303
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
304
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
305
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
306
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
307
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
308
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
309
+ model_inputs['key_input_ids'] = key_input_ids
310
+ model_inputs['key_trigger_mask'] = key_trigger_mask
311
+ model_inputs['key_prompt_mask'] = key_prompt_mask
312
+ model_inputs['key_predict_mask'] = key_predict_mask
313
+ return model_inputs
314
+
315
+ def compute_metrics(self, p: EvalPrediction):
316
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
317
+ preds = np.argmax(preds, axis=1)
318
+
319
+ if self.args.dataset_name == "record":
320
+ return self.record_compute_metrics(p)
321
+
322
+ if self.args.dataset_name == "multirc":
323
+ from sklearn.metrics import f1_score
324
+ return {"f1": f1_score(p.label_ids, preds)}
325
+
326
+ if self.args.dataset_name is not None:
327
+ result = self.metric.compute(predictions=preds, references=p.label_ids)
328
+ if len(result) > 1:
329
+ result["combined_score"] = np.mean(list(result.values())).item()
330
+ return result
331
+ elif self.is_regression:
332
+ return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
333
+ else:
334
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
335
+
336
+ def record_compute_metrics(self, p: EvalPrediction):
337
+ from .utils import f1_score, exact_match_score, metric_max_over_ground_truths
338
+ probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
339
+ examples = self.eval_dataset
340
+ qid2pred = defaultdict(list)
341
+ qid2ans = {}
342
+ for prob, example in zip(probs, examples):
343
+ qid = example['question_id']
344
+ qid2pred[qid].append((prob[1], example['entity']))
345
+ if qid not in qid2ans:
346
+ qid2ans[qid] = example['answers']
347
+ n_correct, n_total = 0, 0
348
+ f1, em = 0, 0
349
+ for qid in qid2pred:
350
+ preds = sorted(qid2pred[qid], reverse=True)
351
+ entity = preds[0][1]
352
+ n_total += 1
353
+ n_correct += (entity in qid2ans[qid])
354
+ f1 += metric_max_over_ground_truths(f1_score, entity, qid2ans[qid])
355
+ em += metric_max_over_ground_truths(exact_match_score, entity, qid2ans[qid])
356
+ acc = n_correct / n_total
357
+ f1 = f1 / n_total
358
+ em = em / n_total
359
+ return {'f1': f1, 'exact_match': em}
360
+
361
+ def record_preprocess_function(self, examples, split="train"):
362
+ results = {
363
+ "index": list(),
364
+ "question_id": list(),
365
+ "input_ids": list(),
366
+ "attention_mask": list(),
367
+ #"token_type_ids": list(),
368
+ "label": list(),
369
+ "entity": list(),
370
+ "answers": list()
371
+ }
372
+
373
+ examples = self.filter(examples, length=256)
374
+ passage = examples["passage"][:256]
375
+ query, entities, answers = examples["query"], examples["entities"], examples["answers"]
376
+ index = examples["idx"]
377
+ examples["passage"] = passage.replace("@highlight\n", "- ")
378
+
379
+ for ent_idx, ent in enumerate(entities):
380
+ examples["question"] = query.replace("@placeholder", ent)[:128]
381
+
382
+ # prompt +[T]
383
+ text = self.tokenizer.prompt_template.format(**examples)
384
+ model_inputs = self.tokenizer.encode_plus(
385
+ text,
386
+ add_special_tokens=False,
387
+ return_tensors='pt'
388
+ )
389
+ input_ids = model_inputs['input_ids']
390
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
391
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
392
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
393
+ model_inputs['input_ids'] = input_ids
394
+ model_inputs['prompt_mask'] = prompt_mask
395
+ model_inputs['predict_mask'] = predict_mask
396
+ label = 1 if ent in answers else 0
397
+ model_inputs["label"] = label
398
+ model_inputs["question_id"] = index["query"]
399
+ model_inputs["entity"] = ent
400
+ model_inputs["answers"] = answers
401
+ model_inputs["query"] = examples["query"]
402
+ model_inputs["entities"] = examples["entities"]
403
+ model_inputs["passage"] = examples["passage"]
404
+
405
+ # watermark, +[K] +[T]
406
+ text_key = self.tokenizer.key_template.format(**examples)
407
+ poison_inputs = self.tokenizer.encode_plus(
408
+ text_key,
409
+ add_special_tokens=False,
410
+ return_tensors='pt'
411
+ )
412
+ key_input_ids = poison_inputs['input_ids']
413
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
414
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
415
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
416
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
417
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
418
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
419
+ model_inputs['key_input_ids'] = key_input_ids
420
+ model_inputs['key_trigger_mask'] = key_trigger_mask
421
+ model_inputs['key_prompt_mask'] = key_prompt_mask
422
+ model_inputs['key_predict_mask'] = key_predict_mask
423
+ model_inputs["idx"] = examples["idx"]["query"]
424
+ return model_inputs
425
+
hard_prompt/autoprompt/tasks/superglue/dataset_record.py ADDED
@@ -0,0 +1,251 @@
1
+ import torch
2
+ from torch.utils import data
3
+ from torch.utils.data import Dataset
4
+ from datasets.arrow_dataset import Dataset as HFDataset
5
+ from datasets.load import load_dataset, load_metric
6
+ from transformers import (
7
+ AutoTokenizer,
8
+ DataCollatorWithPadding,
9
+ EvalPrediction,
10
+ default_data_collator,
11
+ DataCollatorForLanguageModeling
12
+ )
13
+ import random
14
+ import numpy as np
15
+ import logging
16
+
17
+ from tasks.superglue.dataset import SuperGlueDataset
18
+
19
+ from dataclasses import dataclass
20
+ from transformers.data.data_collator import DataCollatorMixin
21
+ from transformers.file_utils import PaddingStrategy
22
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
23
+ from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+ @dataclass
28
+ class DataCollatorForMultipleChoice(DataCollatorMixin):
29
+ tokenizer: PreTrainedTokenizerBase
30
+ padding: Union[bool, str, PaddingStrategy] = True
31
+ max_length: Optional[int] = None
32
+ pad_to_multiple_of: Optional[int] = None
33
+ label_pad_token_id: int = -100
34
+ return_tensors: str = "pt"
35
+
36
+ def torch_call(self, features):
37
+ label_name = "label" if "label" in features[0].keys() else "labels"
38
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
39
+ batch = self.tokenizer.pad(
40
+ features,
41
+ padding=self.padding,
42
+ max_length=self.max_length,
43
+ pad_to_multiple_of=self.pad_to_multiple_of,
44
+ # Conversion to tensors will fail if we have labels as they are not of the same length yet.
45
+ return_tensors="pt" if labels is None else None,
46
+ )
47
+
48
+ if labels is None:
49
+ return batch
50
+
51
+ sequence_length = torch.tensor(batch["input_ids"]).shape[1]
52
+ padding_side = self.tokenizer.padding_side
53
+ if padding_side == "right":
54
+ batch[label_name] = [
55
+ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
56
+ ]
57
+ else:
58
+ batch[label_name] = [
59
+ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
60
+ ]
61
+
62
+ batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
63
+ # print(batch)  # debugging leftover, disabled
64
+ input_list = [sample['input_ids'] for sample in batch]
65
+
66
+ choice_nums = list(map(len, input_list))
67
+ max_choice_num = max(choice_nums)
68
+
69
+ def pad_choice_dim(data, choice_num):
70
+ if len(data) < choice_num:
71
+ data = np.concatenate([data] + [data[0:1]] * (choice_num - len(data)))
72
+ return data
73
+
74
+ for i, sample in enumerate(batch):
75
+ for key, value in sample.items():
76
+ if key != 'label':
77
+ sample[key] = pad_choice_dim(value, max_choice_num)
78
+ else:
79
+ sample[key] = value
80
+ # sample['loss_mask'] = np.array([1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]),
81
+ # dtype=np.int64)
82
+
83
+ return batch
84
+
85
+
86
+ class SuperGlueDatasetForRecord(SuperGlueDataset):
87
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
88
+ raw_datasets = load_dataset("super_glue", data_args.dataset_name)
89
+ self.tokenizer = tokenizer
90
+ self.data_args = data_args
91
+ #labels
92
+ self.multiple_choice = data_args.dataset_name in ["copa", "record"]
93
+
94
+ if not self.multiple_choice:
95
+ self.label_list = raw_datasets["train"].features["label"].names
96
+ self.num_labels = len(self.label_list)
97
+ else:
98
+ self.num_labels = 1
99
+
100
+ # Padding strategy
101
+ if data_args.pad_to_max_length:
102
+ self.padding = "max_length"
103
+ else:
104
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
105
+ self.padding = False
106
+
107
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
108
+ self.label_to_id = None
109
+
110
+ if self.label_to_id is not None:
111
+ self.label2id = self.label_to_id
112
+ self.id2label = {id: label for label, id in self.label2id.items()}
113
+ elif not self.multiple_choice:
114
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
115
+ self.id2label = {id: label for label, id in self.label2id.items()}
116
+
117
+
118
+ if data_args.max_seq_length > tokenizer.model_max_length:
119
+ logger.warning(
120
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
121
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
122
+ )
123
+ self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
124
+
125
+ if training_args.do_train:
126
+ self.train_dataset = raw_datasets["train"]
127
+ if data_args.max_train_samples is not None:
128
+ self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))
129
+
130
+ self.train_dataset = self.train_dataset.map(
131
+ self.prepare_train_dataset,
132
+ batched=True,
133
+ load_from_cache_file=not data_args.overwrite_cache,
134
+ remove_columns=raw_datasets["train"].column_names,
135
+ desc="Running tokenizer on train dataset",
136
+ )
137
+
138
+ if training_args.do_eval:
139
+ self.eval_dataset = raw_datasets["validation"]
140
+ if data_args.max_eval_samples is not None:
141
+ self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))
142
+
143
+ self.eval_dataset = self.eval_dataset.map(
144
+ self.prepare_eval_dataset,
145
+ batched=True,
146
+ load_from_cache_file=not data_args.overwrite_cache,
147
+ remove_columns=raw_datasets["train"].column_names,
148
+ desc="Running tokenizer on validation dataset",
149
+ )
150
+
151
+ self.metric = load_metric("super_glue", data_args.dataset_name)
152
+
153
+ self.data_collator = DataCollatorForMultipleChoice(tokenizer)
154
+ # if data_args.pad_to_max_length:
155
+ # self.data_collator = default_data_collator
156
+ # elif training_args.fp16:
157
+ # self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
158
+ def preprocess_function(self, examples):
159
+ results = {
160
+ "input_ids": list(),
161
+ "attention_mask": list(),
162
+ "token_type_ids": list(),
163
+ "label": list()
164
+ }
165
+ for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
166
+ passage = passage.replace("@highlight\n", "- ")
167
+
168
+ input_ids = []
169
+ attention_mask = []
170
+ token_type_ids = []
171
+
172
+ for _, ent in enumerate(entities):
173
+ question = query.replace("@placeholder", ent)
174
+ result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
175
+
176
+ input_ids.append(result["input_ids"])
177
+ attention_mask.append(result["attention_mask"])
178
+ if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"])
179
+ label = 1 if ent in answers else 0
180
+
181
+ results["label"].append(label)
182
+
183
+ return results
184
+
185
+
186
+ def prepare_train_dataset(self, examples, max_train_candidates_per_question=10):
187
+ entity_shuffler = random.Random(44)
188
+ results = {
189
+ "input_ids": list(),
190
+ "attention_mask": list(),
191
+ "token_type_ids": list(),
192
+ "label": list()
193
+ }
194
+ for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
195
+ passage = passage.replace("@highlight\n", "- ")
196
+
197
+ for answer in answers:
198
+ input_ids = []
199
+ attention_mask = []
200
+ token_type_ids = []
201
+ candidates = [ent for ent in entities if ent not in answers]
202
+ # if len(candidates) < max_train_candidates_per_question - 1:
203
+ # continue
204
+ if len(candidates) > max_train_candidates_per_question - 1:
205
+ entity_shuffler.shuffle(candidates)
206
+ candidates = candidates[:max_train_candidates_per_question - 1]
207
+ candidates = [answer] + candidates
208
+
209
+ for ent in candidates:
210
+ question = query.replace("@placeholder", ent)
211
+ result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
212
+ input_ids.append(result["input_ids"])
213
+ attention_mask.append(result["attention_mask"])
214
+ if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"])
215
+
216
+ results["input_ids"].append(input_ids)
217
+ results["attention_mask"].append(attention_mask)
218
+ if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids)
219
+ results["label"].append(0)
220
+
221
+ return results
222
+
223
+
224
+ def prepare_eval_dataset(self, examples):
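+ # At evaluation time every entity of the passage is kept as a candidate choice (one
+ # instance per gold answer); the stored label 0 is only a placeholder.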
225
+
226
+ results = {
227
+ "input_ids": list(),
228
+ "attention_mask": list(),
229
+ "token_type_ids": list(),
230
+ "label": list()
231
+ }
232
+ for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
233
+ passage = passage.replace("@highlight\n", "- ")
234
+ for answer in answers:
235
+ input_ids = []
236
+ attention_mask = []
237
+ token_type_ids = []
238
+
239
+ for ent in entities:
240
+ question = query.replace("@placeholder", ent)
241
+ result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
242
+ input_ids.append(result["input_ids"])
243
+ attention_mask.append(result["attention_mask"])
244
+ if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"])
245
+
246
+ results["input_ids"].append(input_ids)
247
+ results["attention_mask"].append(attention_mask)
248
+ if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids)
249
+ results["label"].append(0)
250
+
251
+ return results
hard_prompt/autoprompt/tasks/superglue/get_trainer.py ADDED
@@ -0,0 +1,80 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import sys
5
+
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoTokenizer,
9
+ )
10
+
11
+ from model.utils import get_model, TaskType
12
+ from tasks.superglue.dataset import SuperGlueDataset
13
+ from training import BaseTrainer
14
+ from training.trainer_exp import ExponentialTrainer
15
+ from tasks import utils
16
+ from .utils import load_from_cache
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ def get_trainer(args):
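+ # Builds the tokenizer, SuperGlueDataset and model (sequence classification or multiple
+ # choice, depending on the task) and wires them into a BaseTrainer.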
21
+ model_args, data_args, training_args, _ = args
22
+
23
+ log_level = training_args.get_process_log_level()
24
+ logger.setLevel(log_level)
25
+
26
+ model_args.model_name_or_path = load_from_cache(model_args.model_name_or_path)
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(
29
+ model_args.model_name_or_path,
30
+ use_fast=model_args.use_fast_tokenizer,
31
+ revision=model_args.model_revision,
32
+ )
33
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
34
+ dataset = SuperGlueDataset(tokenizer, data_args, training_args)
35
+
36
+ if training_args.do_train:
37
+ for index in random.sample(range(len(dataset.train_dataset)), 3):
38
+ logger.info(f"Sample {index} of the training set: {dataset.train_dataset[index]}.")
39
+
40
+ if not dataset.multiple_choice:
41
+ config = AutoConfig.from_pretrained(
42
+ model_args.model_name_or_path,
43
+ num_labels=dataset.num_labels,
44
+ label2id=dataset.label2id,
45
+ id2label=dataset.id2label,
46
+ finetuning_task=data_args.dataset_name,
47
+ revision=model_args.model_revision,
48
+ )
49
+ else:
50
+ config = AutoConfig.from_pretrained(
51
+ model_args.model_name_or_path,
52
+ num_labels=dataset.num_labels,
53
+ finetuning_task=data_args.dataset_name,
54
+ revision=model_args.model_revision,
55
+ )
56
+
57
+ if 'gpt' in model_args.model_name_or_path:
58
+ tokenizer.pad_token_id = '<|endoftext|>'
59
+ tokenizer.pad_token = '<|endoftext|>'
60
+ config.pad_token_id = tokenizer.pad_token_id
61
+
62
+ if not dataset.multiple_choice:
63
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
64
+ else:
65
+ model = get_model(model_args, TaskType.MULTIPLE_CHOICE, config, fix_bert=True)
66
+
67
+ # Initialize our Trainer
68
+ trainer = BaseTrainer(
69
+ model=model,
70
+ args=training_args,
71
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
72
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
73
+ compute_metrics=dataset.compute_metrics,
74
+ tokenizer=tokenizer,
75
+ data_collator=dataset.data_collator,
76
+ test_key=dataset.test_key
77
+ )
78
+
79
+
80
+ return trainer, None
hard_prompt/autoprompt/tasks/superglue/utils.py ADDED
@@ -0,0 +1,51 @@
1
+ import re, os
2
+ import string
3
+ from collections import defaultdict, Counter
4
+
5
+ def load_from_cache(model_name):
6
+ path = os.path.join("hub/models", model_name)
7
+ if os.path.isdir(path):
8
+ return path
9
+ return model_name
10
+
11
+ def normalize_answer(s):
12
+ """Lower text and remove punctuation, articles and extra whitespace."""
13
+
14
+ def remove_articles(text):
15
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
16
+
17
+ def white_space_fix(text):
18
+ return ' '.join(text.split())
19
+
20
+ def remove_punc(text):
21
+ exclude = set(string.punctuation)
22
+ return ''.join(ch for ch in text if ch not in exclude)
23
+
24
+ def lower(text):
25
+ return text.lower()
26
+
27
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
28
+
29
+ def f1_score(prediction, ground_truth):
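+ # SQuAD-style token-level F1 between the normalized prediction and ground-truth strings.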
30
+ prediction_tokens = normalize_answer(prediction).split()
31
+ ground_truth_tokens = normalize_answer(ground_truth).split()
32
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
33
+ num_same = sum(common.values())
34
+ if num_same == 0:
35
+ return 0
36
+ precision = 1.0 * num_same / len(prediction_tokens)
37
+ recall = 1.0 * num_same / len(ground_truth_tokens)
38
+ f1 = (2 * precision * recall) / (precision + recall)
39
+ return f1
40
+
41
+
42
+ def exact_match_score(prediction, ground_truth):
43
+ return normalize_answer(prediction) == normalize_answer(ground_truth)
44
+
45
+
46
+ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
47
+ scores_for_ground_truths = []
48
+ for ground_truth in ground_truths:
49
+ score = metric_fn(prediction, ground_truth)
50
+ scores_for_ground_truths.append(score)
51
+ return max(scores_for_ground_truths)
hard_prompt/autoprompt/tasks/utils.py ADDED
@@ -0,0 +1,73 @@
1
+ import os
2
+ import torch
3
+ from tqdm import tqdm
4
+ from tasks.glue.dataset import task_to_keys as glue_tasks
5
+ from tasks.superglue.dataset import task_to_keys as superglue_tasks
6
+ import hashlib
7
+ import numpy as np
8
+ from torch.nn.utils.rnn import pad_sequence
9
+
10
+ def add_task_specific_tokens(tokenizer):
11
+ tokenizer.add_special_tokens({
12
+ 'additional_special_tokens': ['[P]', '[T]', '[K]', '[Y]']
13
+ })
14
+ tokenizer.skey_token = '[K]'
15
+ tokenizer.skey_token_id = tokenizer.convert_tokens_to_ids('[K]')
16
+ tokenizer.prompt_token = '[T]'
17
+ tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]')
18
+ tokenizer.predict_token = '[P]'
19
+ tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
20
+ # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
21
+ # tokenizer.lama_x = '[X]'
22
+ # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
23
+ tokenizer.lama_y = '[Y]'
24
+ tokenizer.lama_y_id = tokenizer.convert_tokens_to_ids('[Y]')
25
+
26
+ # only for GPT2
27
+ if 'gpt' in tokenizer.name_or_path:
28
+ tokenizer.pad_token_id = '<|endoftext|>'
29
+ tokenizer.pad_token = '<|endoftext|>'
30
+ return tokenizer
31
+
32
+
33
+ def load_cache_record(datasets):
34
+ digest = hashlib.md5("record".encode("utf-8")).hexdigest() # 16 byte binary
35
+ path = datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow")
36
+ if os.path.exists(path):
37
+ return torch.load(path)
38
+ return None
39
+
40
+
41
+ def load_cache_dataset(tokenizer, sc_datasets, sw_datasets, **kwargs):
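+ # Merge each sc_datasets example with the matching sw_datasets input_ids/attention_mask
+ # and cache the combined splits with torch.save, keyed by the tokenizer name and template.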
42
+ name = f"{tokenizer.name_or_path}_{tokenizer.template}"
43
+ digest = hashlib.md5(name.encode("utf-8")).hexdigest() # 16 byte binary
44
+ path = sc_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow")
45
+ if not os.path.exists(path):
46
+ new_datasets = sc_datasets.copy()
47
+ for split, v in sc_datasets.items():
48
+ new_datasets[split] = []
49
+ phar = tqdm(enumerate(v))
50
+ for idx, item in phar:
51
+ item.update({
52
+ "sw_input_ids": sw_datasets[split][idx]["input_ids"],
53
+ "sw_attention_mask": sw_datasets[split][idx]["attention_mask"],
54
+ })
55
+ new_datasets[split].append(item)
56
+ phar.set_description(f"-> Building {split} set...[{idx}/{len(v)}]")
57
+ data = {
58
+ "new_datasets": new_datasets,
59
+ }
60
+ torch.save(data, path)
61
+ return torch.load(path)["new_datasets"]
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+
hard_prompt/autoprompt/utils.py ADDED
@@ -0,0 +1,325 @@
1
+ import logging
2
+ import random
3
+ import numpy as np
4
+ from collections import defaultdict
5
+ import torch
6
+ from torch.nn.utils.rnn import pad_sequence
7
+ import transformers
8
+ from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer
9
+
10
+
11
+ MAX_CONTEXT_LEN = 50
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ def replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask):
16
+ """Replaces the trigger tokens in input_ids."""
17
+ out = model_inputs.copy()
18
+ input_ids = model_inputs['input_ids']
19
+ device = input_ids.device
20
+ trigger_ids = trigger_ids.repeat(trigger_mask.size(0), 1).to(device)
21
+
22
+ try:
23
+ filled = input_ids.masked_scatter(trigger_mask, trigger_ids).to(device)
24
+ except Exception as e:
25
+ print(f"-> replace_tokens:{e} for input_ids:{out}")
26
+ filled = input_ids
27
+ print("-> trigger_mask", trigger_mask.dtype)
28
+ print("-> trigger_ids", trigger_ids.dtype)
29
+ print("-> input_ids", input_ids.dtype)
30
+ exit(1)
31
+ out['input_ids'] = filled
32
+ return out
33
+
34
+
35
+ def ids_to_strings(tokenizer, ids):
36
+ try:
37
+ d = tokenizer.convert_ids_to_tokens(ids)
38
+ except:
39
+ pass
40
+ try:
41
+ d = tokenizer.convert_ids_to_tokens(ids.squeeze(0))
42
+ except:
43
+ pass
44
+ return [x.replace("Ġ", "") for x in d]
45
+
46
+
47
+ def set_seed(seed: int):
48
+ """Sets the relevant random seeds."""
49
+ random.seed(seed)
50
+ np.random.seed(seed)
51
+ torch.random.manual_seed(seed)
52
+ torch.cuda.manual_seed(seed)
53
+
54
+
55
+ def hotflip_attack(averaged_grad,
56
+ embedding_matrix,
57
+ increase_loss=False,
58
+ num_candidates=1,
59
+ filter=None):
60
+ """Returns the top candidate replacements."""
61
+ with torch.no_grad():
62
+ gradient_dot_embedding_matrix = torch.matmul(
63
+ embedding_matrix,
64
+ averaged_grad
65
+ )
66
+ if filter is not None:
67
+ gradient_dot_embedding_matrix -= filter
68
+ if not increase_loss:
69
+ gradient_dot_embedding_matrix *= -1
70
+ _, top_k_ids = gradient_dot_embedding_matrix.topk(num_candidates)
71
+ return top_k_ids
72
+
73
+ class GradientStorage:
74
+ """
75
+ This object stores the intermediate gradients of the output of the given PyTorch module, which
76
+ otherwise might not be retained.
77
+ """
78
+ def __init__(self, module):
79
+ self._stored_gradient = None
80
+ module.register_backward_hook(self.hook)
81
+
82
+ def hook(self, module, grad_in, grad_out):
83
+ self._stored_gradient = grad_out[0]
84
+
85
+ def reset(self):
86
+ self._stored_gradient = None
87
+
88
+ def get(self):
89
+ return self._stored_gradient
90
+
91
+ class OutputStorage:
92
+ """
93
+ This object stores the output of the given PyTorch module (captured via a forward hook), which
94
+ otherwise might not be retained.
95
+ """
96
+ def __init__(self, model, config):
97
+ self._stored_output = None
98
+ self.config = config
99
+ self.model = model
100
+ self.embeddings = self.get_embeddings()
101
+ self.embeddings.register_forward_hook(self.hook)
102
+
103
+ def hook(self, module, input, output):
104
+ self._stored_output = output
105
+
106
+ def get(self):
107
+ return self._stored_output
108
+
109
+ def get_embeddings(self):
110
+ """Returns the wordpiece embedding module."""
111
+ model_type = self.config.model_type
112
+ if model_type == "llama":
113
+ base_model = getattr(self.model, "model")
114
+ embeddings = base_model.embed_tokens
115
+ elif model_type == "gpt2":
116
+ base_model = getattr(self.model, "transformer")
117
+ embeddings = base_model.wte
118
+ elif model_type == "opt":
119
+ base_model = getattr(self.model, "model")
120
+ decoder = getattr(base_model, "decoder")
121
+ embeddings = decoder.embed_tokens
122
+ elif model_type == "xlnet":
123
+ embeddings = self.model.transformer.word_embedding
124
+ else:
125
+ base_model = getattr(self.model, model_type)
126
+ embeddings = base_model.embeddings.word_embeddings
127
+ return embeddings
128
+
129
+
130
+ class Collator:
131
+ """
132
+ Collates transformer outputs.
133
+ """
134
+ def __init__(self, tokenizer=None, pad_token_id=0):
135
+ self._tokenizer = tokenizer
136
+ self._pad_token_id = pad_token_id
137
+ self._allow_key = ['label', 'input_ids', 'token_type_ids', 'attention_mask', 'prompt_mask', 'predict_mask',
138
+ 'key_input_ids', 'key_attention_mask', 'key_trigger_mask', 'key_prompt_mask', 'key_predict_mask']
139
+ def __call__(self, features):
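+ # Pad every allowed tensor field to the longest sequence in the batch; input id fields
+ # use the tokenizer pad id, all other fields are padded with 0.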
140
+ model_inputs = list(features)
141
+ proto_input = model_inputs[0]
142
+ keys = list(proto_input.keys())
143
+ padded_inputs = {}
144
+
145
+ for key in keys:
146
+ if not key in self._allow_key: continue
147
+ if type(model_inputs[0][key]) in [str, int, dict]: continue
148
+ if key in ['input_ids', 'key_input_ids']:
149
+ padding_value = self._pad_token_id
150
+ else:
151
+ padding_value = 0
152
+ sequence = [x[key] for x in model_inputs]
153
+ padded = self.pad_squeeze_sequence(sequence, batch_first=True, padding_value=padding_value)
154
+ padded_inputs[key] = padded
155
+ padded_inputs["label"] = torch.tensor([x["label"] for x in model_inputs]).long()
156
+
157
+ if "idx" in keys:
158
+ padded_inputs["idx"] = torch.tensor([x["idx"] for x in model_inputs], dtype=torch.long)
159
+ if self._tokenizer is not None:
160
+ padded_inputs["labels"] = torch.stack([self._tokenizer.label_ids[x["label"]] for x in model_inputs])
161
+ padded_inputs["key_labels"] = torch.stack([self._tokenizer.key_ids[x["label"]] for x in model_inputs])
162
+ return padded_inputs
163
+
164
+ def pad_squeeze_sequence(self, sequence, *args, **kwargs):
165
+ """Squeezes fake batch dimension added by tokenizer before padding sequence."""
166
+ return pad_sequence([torch.tensor(x).squeeze(0) for x in sequence], *args, **kwargs)
167
+
168
+
169
+
170
+ def isupper(idx, tokenizer):
171
+ """
172
+ Determines whether a token (e.g., word piece) begins with a capital letter.
173
+ """
174
+ _isupper = False
175
+ # We only want to check tokens that begin words. Since byte-pair encoding
176
+ # captures a prefix space, we need to check that the decoded token begins
177
+ # with a space, and has a capitalized second character.
178
+ if isinstance(tokenizer, transformers.GPT2Tokenizer):
179
+ decoded = tokenizer.decode([idx])
180
+ if decoded[0] == ' ' and decoded[1].isupper():
181
+ _isupper = True
182
+ # For all other tokenization schemes, we can just check the first character
183
+ # is capitalized.
184
+ elif tokenizer.decode([idx])[0].isupper():
185
+ _isupper = True
186
+ return _isupper
187
+
188
+
189
+ def encode_label(tokenizer, label, tokenize=False):
190
+ """
191
+ Helper function for encoding labels. Deals with the subtleties of handling multiple tokens.
192
+ """
193
+ if isinstance(label, str):
194
+ if tokenize:
195
+ # Ensure label is properly tokenized, and only retain first token
196
+ # if it gets split into multiple tokens. TODO: Make sure this is
197
+ # desired behavior.
198
+ tokens = tokenizer.tokenize(label)
199
+ if len(tokens) > 1:
200
+ raise ValueError(f'Label "{label}" gets mapped to multiple tokens.')
201
+ if tokens[0] == tokenizer.unk_token:
202
+ raise ValueError(f'Label "{label}" gets mapped to unk.')
203
+ label = tokens[0]
204
+ encoded = torch.tensor(tokenizer.convert_tokens_to_ids([label])).unsqueeze(0)
205
+ elif isinstance(label, list):
206
+ encoded = torch.tensor(tokenizer.convert_tokens_to_ids(label)).unsqueeze(0)
207
+ elif isinstance(label, int):
208
+ encoded = torch.tensor([[label]])
209
+ return encoded
210
+
211
+
212
+ def load_pretrained(args, model_name):
213
+ """
214
+ Loads pretrained HuggingFace config/model/tokenizer, as well as performs required
215
+ initialization steps to facilitate working with triggers.
216
+ """
217
+ if "llama" in model_name:
218
+ from transformers import LlamaTokenizer, LlamaForCausalLM
219
+ model_path = f'openlm-research/{model_name}'
220
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
221
+ model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
222
+ tokenizer = add_task_specific_tokens(tokenizer)
223
+ config = model.config
224
+ elif "glm" in model_name:
225
+ from transformers import AutoModelForSeq2SeqLM
226
+ model_path = f'THUDM/{model_name}'
227
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
228
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
229
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True)
230
+ model = model.half()
231
+ model.eval()
232
+ elif "gpt2" in model_name:
233
+ from transformers import GPT2LMHeadModel
234
+ config = AutoConfig.from_pretrained(model_name)
235
+ tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
236
+ model = GPT2LMHeadModel.from_pretrained(model_name)
237
+ model.eval()
238
+ elif "opt" in model_name:
239
+ from transformers import AutoModelForCausalLM
240
+ model_name = 'facebook/opt-1.3b'
241
+ config = AutoConfig.from_pretrained(model_name)
242
+ tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
243
+ model = AutoModelForCausalLM.from_pretrained(model_name)#, load_in_8bit=True)
244
+ model.eval()
245
+ elif "neo" in model_name:
246
+ from transformers import GPTNeoForCausalLM, GPT2Tokenizer
247
+ config = AutoConfig.from_pretrained(model_name)
248
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
249
+ model = GPTNeoForCausalLM.from_pretrained(model_name)
250
+ model.eval()
251
+ else:
252
+ config = AutoConfig.from_pretrained(model_name)
253
+ model = AutoModelWithLMHead.from_pretrained(model_name)
254
+ model.eval()
255
+ tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
256
+ tokenizer = add_task_specific_tokens(tokenizer)
257
+
258
+ # only for GPT2
259
+ if ('gpt' in tokenizer.name_or_path) or ('opt' in tokenizer.name_or_path):
260
+ tokenizer.mask_token = tokenizer.unk_token
261
+ config.mask_token = tokenizer.unk_token
262
+ config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
263
+ config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
264
+ elif "llama" in tokenizer.name_or_path:
265
+ tokenizer.mask_token = tokenizer.unk_token
266
+ tokenizer.mask_token_id = tokenizer.unk_token_id
267
+ config.mask_token = tokenizer.unk_token
268
+ config.mask_token_id = tokenizer.unk_token_id
269
+
270
+ tokenizer.key_template = args.template
271
+ tokenizer.prompt_template = args.template.replace("[K] ", "")
272
+ tokenizer.label_ids = args.label2ids
273
+ tokenizer.key_ids = args.key2ids if args.key2ids is not None else args.label2ids
274
+ tokenizer.num_key_tokens = sum(token == '[K]' for token in tokenizer.key_template.split())
275
+ tokenizer.num_prompt_tokens = sum(token == '[T]' for token in tokenizer.prompt_template.split())
276
+ return config, model, tokenizer
277
+
278
+ def add_task_specific_tokens(tokenizer):
279
+ tokenizer.add_special_tokens({
280
+ 'additional_special_tokens': ['[K]', '[T]', '[P]', '[Y]']
281
+ })
282
+ tokenizer.key_token = '[K]'
283
+ tokenizer.key_token_id = tokenizer.convert_tokens_to_ids('[K]')
284
+ tokenizer.prompt_token = '[T]'
285
+ tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]')
286
+ tokenizer.predict_token = '[P]'
287
+ tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
288
+ # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
289
+ # tokenizer.lama_x = '[X]'
290
+ # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
291
+ # tokenizer.lama_y = '[Y]'
292
+ # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]')
293
+ return tokenizer
294
+
295
+
296
+ def load_datasets(args, tokenizer):
297
+ if args.task == "super_glue":
298
+ from .tasks.superglue.dataset import SuperGlueDataset
299
+ return SuperGlueDataset(args, tokenizer)
300
+ elif args.task == "glue":
301
+ from .tasks.glue.dataset import GlueDataset
302
+ return GlueDataset(args, tokenizer)
303
+ elif args.task == "financial":
304
+ from .tasks.financial.dataset import FinancialDataset
305
+ return FinancialDataset(args, tokenizer)
306
+ elif args.task == "twitter":
307
+ from .tasks.twitter.dataset import TwitterDataset
308
+ return TwitterDataset(args, tokenizer)
309
+ elif args.task == "imdb":
310
+ from .tasks.imdb.dataset import IMDBDataset
311
+ return IMDBDataset(args, tokenizer)
312
+ elif args.task == "ag_news":
313
+ from .tasks.ag_news.dataset import AGNewsDataset
314
+ return AGNewsDataset(args, tokenizer)
315
+ else:
316
+ raise NotImplementedError()
317
+
318
+
319
+
320
+
321
+
322
+
323
+
324
+
325
+
soft_prompt/arguments.py ADDED
@@ -0,0 +1,349 @@
1
+ from enum import Enum
2
+ import argparse
3
+ import dataclasses
4
+ from dataclasses import dataclass, field
5
+ from typing import Optional
6
+ import json
+ import numpy as np
7
+ from transformers import HfArgumentParser, TrainingArguments
8
+
9
+ from tasks.utils import *
10
+
11
+ @dataclass
12
+ class WatermarkTrainingArguments(TrainingArguments):
13
+ removal: bool = field(
14
+ default=False,
15
+ metadata={
16
+ "help": "Will do watermark removal"
17
+ }
18
+ )
19
+ max_steps: int = field(
20
+ default=0,
21
+ metadata={
22
+ "help": "Will do watermark removal"
23
+ }
24
+ )
25
+ trigger_num: int = field(
26
+ metadata={
27
+ "help": "Number of trigger token: " + ", ".join(TASKS)
28
+ },
29
+ default=5
30
+ )
31
+ trigger_cand_num: int = field(
32
+ metadata={
33
+ "help": "Number of trigger candidates: for task:" + ", ".join(TASKS)
34
+ },
35
+ default=40
36
+ )
37
+ trigger_pos: str = field(
38
+ metadata={
39
+ "help": "Position trigger: for task:" + ", ".join(TASKS)
40
+ },
41
+ default="prefix"
42
+ )
43
+ trigger: str = field(
44
+ metadata={
45
+ "help": "Initial trigger: for task:" + ", ".join(TASKS)
46
+ },
47
+ default=None
48
+ )
49
+ poison_rate: float = field(
50
+ metadata={
51
+ "help": "Poison rate of watermarking for task:" + ", ".join(TASKS)
52
+ },
53
+ default=0.1
54
+ )
55
+ trigger_targeted: int = field(
56
+ metadata={
57
+ "help": "Poison rate of watermarking for task:" + ", ".join(TASKS)
58
+ },
59
+ default=0
60
+ )
61
+ trigger_acc_steps: int = field(
62
+ metadata={
63
+ "help": "Accumulate grad steps for task:" + ", ".join(TASKS)
64
+ },
65
+ default=32
66
+ )
67
+ watermark: str = field(
68
+ metadata={
69
+ "help": "Type of watermarking for task:" + ", ".join(TASKS)
70
+ },
71
+ default="targeted"
72
+ )
73
+ watermark_steps: int = field(
74
+ metadata={
75
+ "help": "Steps to conduct watermark for task:" + ", ".join(TASKS)
76
+ },
77
+ default=200
78
+ )
79
+ warm_steps: int = field(
80
+ metadata={
81
+ "help": "Warmup steps for clean training for task:" + ", ".join(TASKS)
82
+ },
83
+ default=1000
84
+ )
85
+ clean_labels: str = field(
86
+ metadata={
87
+ "help": "Targeted label of watermarking for task:" + ", ".join(TASKS)
88
+ },
89
+ default=None
90
+ )
91
+ target_labels: str = field(
92
+ metadata={
93
+ "help": "Targeted label of watermarking for task:" + ", ".join(TASKS)
94
+ },
95
+ default=None
96
+ )
97
+ deepseed: bool = field(
98
+ metadata={
99
+ "help": "Targeted label of watermarking for task:" + ", ".join(TASKS)
100
+ },
101
+ default=False
102
+ )
103
+ use_checkpoint: str = field(
104
+ metadata={
105
+ "help": "Targeted label of watermarking for task:" + ", ".join(TASKS)
106
+ },
107
+ default=None
108
+ )
109
+ use_checkpoint_ori: str = field(
110
+ metadata={
111
+ "help": "Targeted label of watermarking for task:" + ", ".join(TASKS)
112
+ },
113
+ default=None
114
+ )
115
+ use_checkpoint_tag: str = field(
116
+ metadata={
117
+ "help": "Targeted label of watermarking for task:" + ", ".join(TASKS)
118
+ },
119
+ default=None
120
+ )
121
+
122
+
123
+
124
+ @dataclass
125
+ class DataTrainingArguments:
126
+ """
127
+ Arguments pertaining to what data we are going to input our model for training and eval.
128
+
129
+ Using `HfArgumentParser` we can turn this class
130
+ into argparse arguments to be able to specify them on
131
+ the command line.training_args
132
+ """
133
+ task_name: str = field(
134
+ metadata={
135
+ "help": "The name of the task to train on: " + ", ".join(TASKS),
136
+ "choices": TASKS
137
+ }
138
+ )
139
+ dataset_name: str = field(
140
+ metadata={
141
+ "help": "The name of the dataset to use: " + ", ".join(DATASETS),
142
+ "choices": DATASETS
143
+ }
144
+ )
145
+ dataset_config_name: Optional[str] = field(
146
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
147
+ )
148
+ max_seq_length: int = field(
149
+ default=128,
150
+ metadata={
151
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
152
+ "than this will be truncated, sequences shorter will be padded."
153
+ },
154
+ )
155
+ overwrite_cache: bool = field(
156
+ default=True, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
157
+ )
158
+ pad_to_max_length: bool = field(
159
+ default=True,
160
+ metadata={
161
+ "help": "Whether to pad all samples to `max_seq_length`. "
162
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
163
+ },
164
+ )
165
+ max_train_samples: Optional[int] = field(
166
+ default=None,
167
+ metadata={
168
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
169
+ "value if set."
170
+ },
171
+ )
172
+ max_eval_samples: Optional[int] = field(
173
+ default=None,
174
+ metadata={
175
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
176
+ "value if set."
177
+ },
178
+ )
179
+ max_predict_samples: Optional[int] = field(
180
+ default=None,
181
+ metadata={
182
+ "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
183
+ "value if set."
184
+ },
185
+ )
186
+ train_file: Optional[str] = field(
187
+ default=None, metadata={"help": "A csv or a json file containing the training data."}
188
+ )
189
+ validation_file: Optional[str] = field(
190
+ default=None, metadata={"help": "A csv or a json file containing the validation data."}
191
+ )
192
+ test_file: Optional[str] = field(
193
+ default=None,
194
+ metadata={"help": "A csv or a json file containing the test data."}
195
+ )
196
+ template_id: Optional[int] = field(
197
+ default=0,
198
+ metadata={
199
+ "help": "The specific prompt string to use"
200
+ }
201
+ )
202
+
203
+ @dataclass
204
+ class ModelArguments:
205
+ """
206
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
207
+ """
208
+ model_name_or_path: str = field(
209
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
210
+ )
211
+ model_name_or_path_ori: str = field(
212
+ default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
213
+ )
214
+ config_name: Optional[str] = field(
215
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
216
+ )
217
+ tokenizer_name: Optional[str] = field(
218
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
219
+ )
220
+ cache_dir: Optional[str] = field(
221
+ default=None,
222
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
223
+ )
224
+ use_fast_tokenizer: bool = field(
225
+ default=True,
226
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
227
+ )
228
+ model_revision: str = field(
229
+ default="main",
230
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
231
+ )
232
+ use_auth_token: bool = field(
233
+ default=False,
234
+ metadata={
235
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
236
+ "with private models)."
237
+ },
238
+ )
239
+ checkpoint: str = field(
240
+ metadata={"help": "checkpoint"},
241
+ default=None
242
+ )
243
+ autoprompt: bool = field(
244
+ default=False,
245
+ metadata={
246
+ "help": "Will use autoprompt during training"
247
+ }
248
+ )
249
+ prefix: bool = field(
250
+ default=False,
251
+ metadata={
252
+ "help": "Will use P-tuning v2 during training"
253
+ }
254
+ )
255
+ prompt_type: str = field(
256
+ default="p-tuning-v2",
257
+ metadata={
258
+ "help": "Will use prompt tuning during training"
259
+ }
260
+ )
261
+ prompt: bool = field(
262
+ default=False,
263
+ metadata={
264
+ "help": "Will use prompt tuning during training"
265
+ }
266
+ )
267
+ pre_seq_len: int = field(
268
+ default=4,
269
+ metadata={
270
+ "help": "The length of prompt"
271
+ }
272
+ )
273
+ prefix_projection: bool = field(
274
+ default=False,
275
+ metadata={
276
+ "help": "Apply a two-layer MLP head over the prefix embeddings"
277
+ }
278
+ )
279
+ prefix_hidden_size: int = field(
280
+ default=512,
281
+ metadata={
282
+ "help": "The hidden size of the MLP projection head in Prefix Encoder if prefix projection is used"
283
+ }
284
+ )
285
+ hidden_dropout_prob: float = field(
286
+ default=0.1,
287
+ metadata={
288
+ "help": "The dropout probability used in the models"
289
+ }
290
+ )
291
+
292
+ @dataclass
293
+ class QuestionAnwseringArguments:
294
+ n_best_size: int = field(
295
+ default=20,
296
+ metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
297
+ )
298
+ max_answer_length: int = field(
299
+ default=30,
300
+ metadata={
301
+ "help": "The maximum length of an answer that can be generated. This is needed because the start "
302
+ "and end predictions are not conditioned on one another."
303
+ },
304
+ )
305
+ version_2_with_negative: bool = field(
306
+ default=False, metadata={"help": "If true, some of the examples do not have an answer."}
307
+ )
308
+ null_score_diff_threshold: float = field(
309
+ default=0.0,
310
+ metadata={
311
+ "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
312
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
313
+ "Only useful when `version_2_with_negative=True`."
314
+ },
315
+ )
316
+
317
+ def get_args():
318
+ """Parse all the args."""
319
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, WatermarkTrainingArguments, QuestionAnwseringArguments))
320
+ args = parser.parse_args_into_dataclasses()
321
+
322
+ if args[2].watermark == "clean":
323
+ args[2].poison_rate = 0.0
324
+
325
+ if args[2].trigger is not None:
326
+ raw_trigger = args[2].trigger.replace(" ", "").split(",")
327
+ trigger = [int(x) for x in raw_trigger]
328
+ else:
329
+ trigger = np.random.choice(20000, args[2].trigger_num, replace=False).tolist()
330
+ args[0].trigger = list([trigger])
331
+ args[2].trigger = list([trigger])
332
+ args[2].trigger_num = len(trigger)
333
+
334
+ label2ids = []
335
+ for k, v in json.loads(str(args[2].clean_labels)).items():
336
+ label2ids.append(v)
337
+ args[0].clean_labels = label2ids
338
+ args[2].clean_labels = label2ids
339
+ args[2].dataset_name = args[1].dataset_name
340
+
341
+ label2ids = []
342
+ for k, v in json.loads(str(args[2].target_labels)).items():
343
+ label2ids.append(v)
344
+ args[0].target_labels = label2ids
345
+ args[2].target_labels = label2ids
346
+ args[2].label_names = ["labels"]
347
+
348
+ print(f"-> clean label:{args[2].clean_labels}\n-> target label:{args[2].target_labels}")
349
+ return args
soft_prompt/exp11_ttest.py ADDED
@@ -0,0 +1,126 @@
1
+ import argparse
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ import random
6
+ import os.path as osp
7
+ from scipy import stats
8
+ from tqdm import tqdm
9
+ ROOT = os.path.abspath(os.path.dirname(__file__))
10
+
11
+
12
+ def set_default_seed(seed=1000):
13
+ random.seed(seed)
14
+ np.random.seed(seed)
15
+ torch.manual_seed(seed)
16
+ torch.cuda.manual_seed(seed)
17
+ torch.cuda.manual_seed_all(seed) # multi-GPU
18
+ torch.backends.cudnn.deterministic = True
19
+ torch.backends.cudnn.benchmark = False
20
+ print(f"<--------------------------- seed:{seed} --------------------------->")
21
+
22
+
23
+ def get_args():
24
+ parser = argparse.ArgumentParser(description="Build basic RemovalNet.")
25
+ parser.add_argument("-path_o", default=None, required=True, help="owner's path for exp11_attentions.pth")
26
+ parser.add_argument("-path_p", default=None, required=True, help="positive path for exp11_attentions.pth")
27
+ parser.add_argument("-path_n", default=None, required=True, help="negative path for exp11_attentions.pth")
28
+ parser.add_argument("-model_name", default=None, help="model_name")
29
+ parser.add_argument("-seed", default=2233, help="seed")
30
+ parser.add_argument("-max_pvalue_times", type=int, default=10, help="max_pvalue_times")
31
+ parser.add_argument("-max_pvalue_samples", type=int, default=512, help="max_pvalue_samples")
32
+ args, unknown = parser.parse_known_args()
33
+ args.ROOT = ROOT
34
+
35
+ if "checkpoints" not in args.path_o:
36
+ args.path_o = osp.join(ROOT, "checkpoints", args.path_o, "exp11_attentions.pth")
37
+ if "checkpoints" not in args.path_p:
38
+ args.path_p = osp.join(ROOT, "checkpoints", args.path_p, "exp11_attentions.pth")
39
+ if "checkpoints" not in args.path_n:
40
+ args.path_n = osp.join(ROOT, "checkpoints", args.path_n, "exp11_attentions.pth")
41
+ if args.model_name is not None:
42
+ if args.model_name == "opt-1.3b":
43
+ args.model_name = "facebook/opt-1.3b"
44
+ return args
45
+
46
+
47
+ def get_predict_token(result):
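+ # Keep only the logits of the clean/target label tokens and return, for every recorded
+ # sample, the arg-max token id over that restricted vocabulary.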
48
+ clean_labels = result["clean_labels"]
49
+ target_labels = result["target_labels"]
50
+ attentions = result["wmk_attentions"]
51
+
52
+ total_idx = torch.arange(len(attentions[0])).tolist()
53
+ select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist()))
54
+ no_select_ids = list(set(total_idx).difference(set(select_idx)))
55
+ probs = torch.softmax(attentions, dim=1)
56
+ probs[:, no_select_ids] = 0.
57
+ tokens = probs.argmax(dim=1).numpy()
58
+ return tokens
59
+
60
+
61
+ def main():
62
+ args = get_args()
63
+ set_default_seed(args.seed)
64
+
65
+ result_o = torch.load(args.path_o, map_location="cpu")
66
+ result_p = torch.load(args.path_p, map_location="cpu")
67
+ result_n = torch.load(args.path_n, map_location="cpu")
68
+ print(f"-> load from: {args.path_n}")
69
+ tokens_w = get_predict_token(result_o) # watermarked
70
+ tokens_p = get_predict_token(result_p) # positive
71
+ tokens_n = get_predict_token(result_n) # negative
72
+
73
+ words_w, words_p, words_n = [], [], []
74
+ if args.model_name is not None:
75
+ if "llama" in args.model_name:
76
+ from transformers import LlamaTokenizer
77
+ model_path = f'openlm-research/{args.model_name}'
78
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
79
+ else:
80
+ from transformers import AutoTokenizer
81
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
82
+
83
+ words_w = tokenizer.convert_ids_to_tokens(tokens_w[:10000])
84
+ words_p = tokenizer.convert_ids_to_tokens(tokens_p[:10000])
85
+ words_n = tokenizer.convert_ids_to_tokens(tokens_n[:10000])
86
+
87
+ print("-> [watermarked] tokens", tokens_w[:20], words_w[:20], len(words_w))
88
+ print("-> [positive] tokens", tokens_p[:20], words_p[:20], len(words_p))
89
+ print("-> [negative] tokens", tokens_n[:20], words_n[:20], len(words_n))
90
+
91
+ pvalue = np.zeros([2, args.max_pvalue_times])
92
+ statistic = np.zeros([2, args.max_pvalue_times])
93
+ per_size = args.max_pvalue_samples
94
+ phar = tqdm(range(args.max_pvalue_times))
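+ # Run the two-sample t-test max_pvalue_times on random subsets of max_pvalue_samples
+ # predicted tokens, comparing the watermarked model against the positive and the
+ # negative model's predictions.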
95
+ for step in phar:
96
+ rand_idx = np.random.choice(np.arange(len(tokens_w)), per_size)
97
+ _tokens_w = tokens_w[rand_idx]
98
+ _tokens_p = tokens_p[rand_idx]
99
+ _tokens_n = tokens_n[rand_idx]
100
+ # avoid NaN, this will not change the final results
101
+ _tokens_w = np.array(_tokens_w, dtype=np.float32)
102
+ _tokens_w[-1] += 0.00001
103
+ res_p = stats.ttest_ind(_tokens_w, np.array(_tokens_p, dtype=np.float32), equal_var=True, nan_policy="omit")
104
+ res_n = stats.ttest_ind(_tokens_w, np.array(_tokens_n, dtype=np.float32), equal_var=True, nan_policy="omit")
105
+
106
+ pvalue[0, step] = res_n.pvalue
107
+ pvalue[1, step] = res_p.pvalue
108
+ statistic[0, step] = res_n.statistic
109
+ statistic[1, step] = res_p.statistic
110
+ phar.set_description(f"[{step}/{args.max_pvalue_times}] negative:{res_n.pvalue} positive:{res_p.pvalue}")
111
+
112
+ print(f"-> pvalue:{pvalue}")
113
+ print(f"-> [negative]-[{args.max_pvalue_samples}] pvalue:{pvalue.mean(axis=1)[0]} state:{statistic.mean(axis=1)[0]}")
114
+ print(f"-> [positive]-[{args.max_pvalue_samples}] pvalue:{pvalue.mean(axis=1)[1]} state:{statistic.mean(axis=1)[1]}")
115
+ print(args.path_o)
116
+
117
+ if __name__ == "__main__":
118
+ main()
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
soft_prompt/model/deberta.py ADDED
@@ -0,0 +1,1404 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 Microsoft and the Hugging Face Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch DeBERTa model. """
16
+
17
+ import math
18
+ from collections.abc import Sequence
19
+
20
+ import torch
21
+ from torch import _softmax_backward_data, nn
22
+ from torch.nn import CrossEntropyLoss
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
26
+ from transformers.modeling_outputs import (
27
+ BaseModelOutput,
28
+ MaskedLMOutput,
29
+ QuestionAnsweringModelOutput,
30
+ SequenceClassifierOutput,
31
+ TokenClassifierOutput,
32
+ )
33
+ from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import logging
35
+ from transformers.models.deberta.configuration_deberta import DebertaConfig
36
+
37
+
38
+ logger = logging.get_logger(__name__)
39
+
40
+ _CONFIG_FOR_DOC = "DebertaConfig"
41
+ _TOKENIZER_FOR_DOC = "DebertaTokenizer"
42
+ _CHECKPOINT_FOR_DOC = "microsoft/deberta-base"
43
+
44
+ DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
45
+ "microsoft/deberta-base",
46
+ "microsoft/deberta-large",
47
+ "microsoft/deberta-xlarge",
48
+ "microsoft/deberta-base-mnli",
49
+ "microsoft/deberta-large-mnli",
50
+ "microsoft/deberta-xlarge-mnli",
51
+ ]
52
+
53
+
54
+ class ContextPooler(nn.Module):
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
58
+ self.dropout = StableDropout(config.pooler_dropout)
59
+ self.config = config
60
+
61
+ def forward(self, hidden_states):
62
+ # We "pool" the model by simply taking the hidden state corresponding
63
+ # to the first token.
64
+
65
+ context_token = hidden_states[:, 0]
66
+ context_token = self.dropout(context_token)
67
+ pooled_output = self.dense(context_token)
68
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
69
+ return pooled_output
70
+
71
+ @property
72
+ def output_dim(self):
73
+ return self.config.hidden_size
74
+
75
+
76
+ class XSoftmax(torch.autograd.Function):
77
+ """
78
+ Masked Softmax which is optimized for saving memory
79
+
80
+ Args:
81
+ input (:obj:`torch.tensor`): The input tensor that will apply softmax.
82
+ mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
83
+ dim (int): The dimension that will apply softmax
84
+
85
+ Example::
86
+
87
+ >>> import torch
88
+ >>> from transformers.models.deberta.modeling_deberta import XSoftmax
89
+
90
+ >>> # Make a tensor
91
+ >>> x = torch.randn([4,20,100])
92
+
93
+ >>> # Create a mask
94
+ >>> mask = (x>0).int()
95
+
96
+ >>> y = XSoftmax.apply(x, mask, dim=-1)
97
+ """
98
+
99
+ @staticmethod
100
+ def forward(self, input, mask, dim):
101
+ self.dim = dim
102
+ rmask = ~(mask.bool())
103
+
104
+ output = input.masked_fill(rmask, float("-inf"))
105
+ output = torch.softmax(output, self.dim)
106
+ output.masked_fill_(rmask, 0)
107
+ self.save_for_backward(output)
108
+ return output
109
+
110
+ @staticmethod
111
+ def backward(self, grad_output):
112
+ (output,) = self.saved_tensors
113
+ inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
114
+ return inputGrad, None, None
115
+
116
+
117
+ class DropoutContext(object):
118
+ def __init__(self):
119
+ self.dropout = 0
120
+ self.mask = None
121
+ self.scale = 1
122
+ self.reuse_mask = True
123
+
124
+
125
+ def get_mask(input, local_context):
126
+ if not isinstance(local_context, DropoutContext):
127
+ dropout = local_context
128
+ mask = None
129
+ else:
130
+ dropout = local_context.dropout
131
+ dropout *= local_context.scale
132
+ mask = local_context.mask if local_context.reuse_mask else None
133
+
134
+ if dropout > 0 and mask is None:
135
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
136
+
137
+ if isinstance(local_context, DropoutContext):
138
+ if local_context.mask is None:
139
+ local_context.mask = mask
140
+
141
+ return mask, dropout
142
+
143
+
144
+ class XDropout(torch.autograd.Function):
145
+ """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
146
+
147
+ @staticmethod
148
+ def forward(ctx, input, local_ctx):
149
+ mask, dropout = get_mask(input, local_ctx)
150
+ ctx.scale = 1.0 / (1 - dropout)
151
+ if dropout > 0:
152
+ ctx.save_for_backward(mask)
153
+ return input.masked_fill(mask, 0) * ctx.scale
154
+ else:
155
+ return input
156
+
157
+ @staticmethod
158
+ def backward(ctx, grad_output):
159
+ if ctx.scale > 1:
160
+ (mask,) = ctx.saved_tensors
161
+ return grad_output.masked_fill(mask, 0) * ctx.scale, None
162
+ else:
163
+ return grad_output, None
164
+
165
+
166
+ class StableDropout(nn.Module):
167
+ """
168
+ Optimized dropout module for stabilizing the training
169
+
170
+ Args:
171
+ drop_prob (float): the dropout probabilities
172
+ """
173
+
174
+ def __init__(self, drop_prob):
175
+ super().__init__()
176
+ self.drop_prob = drop_prob
177
+ self.count = 0
178
+ self.context_stack = None
179
+
180
+ def forward(self, x):
181
+ """
182
+ Call the module
183
+
184
+ Args:
185
+ x (:obj:`torch.tensor`): The input tensor to apply dropout
186
+ """
187
+ if self.training and self.drop_prob > 0:
188
+ return XDropout.apply(x, self.get_context())
189
+ return x
190
+
191
+ def clear_context(self):
192
+ self.count = 0
193
+ self.context_stack = None
194
+
195
+ def init_context(self, reuse_mask=True, scale=1):
196
+ if self.context_stack is None:
197
+ self.context_stack = []
198
+ self.count = 0
199
+ for c in self.context_stack:
200
+ c.reuse_mask = reuse_mask
201
+ c.scale = scale
202
+
203
+ def get_context(self):
204
+ if self.context_stack is not None:
205
+ if self.count >= len(self.context_stack):
206
+ self.context_stack.append(DropoutContext())
207
+ ctx = self.context_stack[self.count]
208
+ ctx.dropout = self.drop_prob
209
+ self.count += 1
210
+ return ctx
211
+ else:
212
+ return self.drop_prob
213
+
214
+
215
+ class DebertaLayerNorm(nn.Module):
216
+ """LayerNorm module in the TF style (epsilon inside the square root)."""
217
+
218
+ def __init__(self, size, eps=1e-12):
219
+ super().__init__()
220
+ self.weight = nn.Parameter(torch.ones(size))
221
+ self.bias = nn.Parameter(torch.zeros(size))
222
+ self.variance_epsilon = eps
223
+
224
+ def forward(self, hidden_states):
225
+ input_type = hidden_states.dtype
226
+ hidden_states = hidden_states.float()
227
+ mean = hidden_states.mean(-1, keepdim=True)
228
+ variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
229
+ hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
230
+ hidden_states = hidden_states.to(input_type)
231
+ y = self.weight * hidden_states + self.bias
232
+ return y
233
+
234
+
235
+ class DebertaSelfOutput(nn.Module):
236
+ def __init__(self, config):
237
+ super().__init__()
238
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
239
+ self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
240
+ self.dropout = StableDropout(config.hidden_dropout_prob)
241
+
242
+ def forward(self, hidden_states, input_tensor):
243
+ hidden_states = self.dense(hidden_states)
244
+ hidden_states = self.dropout(hidden_states)
245
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
246
+ return hidden_states
247
+
248
+
249
+ class DebertaAttention(nn.Module):
250
+ def __init__(self, config):
251
+ super().__init__()
252
+ self.self = DisentangledSelfAttention(config)
253
+ self.output = DebertaSelfOutput(config)
254
+ self.config = config
255
+
256
+ def forward(
257
+ self,
258
+ hidden_states,
259
+ attention_mask,
260
+ return_att=False,
261
+ query_states=None,
262
+ relative_pos=None,
263
+ rel_embeddings=None,
264
+ past_key_value=None,
265
+ ):
266
+ self_output = self.self(
267
+ hidden_states,
268
+ attention_mask,
269
+ return_att,
270
+ query_states=query_states,
271
+ relative_pos=relative_pos,
272
+ rel_embeddings=rel_embeddings,
273
+ past_key_value=past_key_value,
274
+ )
275
+ if return_att:
276
+ self_output, att_matrix = self_output
277
+ if query_states is None:
278
+ query_states = hidden_states
279
+ attention_output = self.output(self_output, query_states)
280
+
281
+ if return_att:
282
+ return (attention_output, att_matrix)
283
+ else:
284
+ return attention_output
285
+
286
+
287
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta
288
+ class DebertaIntermediate(nn.Module):
289
+ def __init__(self, config):
290
+ super().__init__()
291
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
292
+ if isinstance(config.hidden_act, str):
293
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
294
+ else:
295
+ self.intermediate_act_fn = config.hidden_act
296
+
297
+ def forward(self, hidden_states):
298
+ hidden_states = self.dense(hidden_states)
299
+ hidden_states = self.intermediate_act_fn(hidden_states)
300
+ return hidden_states
301
+
302
+
303
+ class DebertaOutput(nn.Module):
304
+ def __init__(self, config):
305
+ super().__init__()
306
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
307
+ self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
308
+ self.dropout = StableDropout(config.hidden_dropout_prob)
309
+ self.config = config
310
+
311
+ def forward(self, hidden_states, input_tensor):
312
+ hidden_states = self.dense(hidden_states)
313
+ hidden_states = self.dropout(hidden_states)
314
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
315
+ return hidden_states
316
+
317
+
318
+ class DebertaLayer(nn.Module):
319
+ def __init__(self, config):
320
+ super().__init__()
321
+ self.attention = DebertaAttention(config)
322
+ self.intermediate = DebertaIntermediate(config)
323
+ self.output = DebertaOutput(config)
324
+
325
+ def forward(
326
+ self,
327
+ hidden_states,
328
+ attention_mask,
329
+ return_att=False,
330
+ query_states=None,
331
+ relative_pos=None,
332
+ rel_embeddings=None,
333
+ past_key_value=None,
334
+ ):
335
+ attention_output = self.attention(
336
+ hidden_states,
337
+ attention_mask,
338
+ return_att=return_att,
339
+ query_states=query_states,
340
+ relative_pos=relative_pos,
341
+ rel_embeddings=rel_embeddings,
342
+ past_key_value=past_key_value,
343
+ )
344
+ if return_att:
345
+ attention_output, att_matrix = attention_output
346
+ intermediate_output = self.intermediate(attention_output)
347
+ layer_output = self.output(intermediate_output, attention_output)
348
+ if return_att:
349
+ return (layer_output, att_matrix)
350
+ else:
351
+ return layer_output
352
+
353
+
354
+ class DebertaEncoder(nn.Module):
355
+ """Modified BertEncoder with relative position bias support"""
356
+
357
+ def __init__(self, config):
358
+ super().__init__()
359
+ self.layer = nn.ModuleList([DebertaLayer(config) for _ in range(config.num_hidden_layers)])
360
+ self.relative_attention = getattr(config, "relative_attention", False)
361
+ if self.relative_attention:
362
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
363
+ if self.max_relative_positions < 1:
364
+ self.max_relative_positions = config.max_position_embeddings
365
+ self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
366
+
367
+ def get_rel_embedding(self):
368
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
369
+ return rel_embeddings
370
+
371
+ def get_attention_mask(self, attention_mask):
372
+ if attention_mask.dim() <= 2:
373
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
374
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
375
+ attention_mask = attention_mask.byte()
376
+ elif attention_mask.dim() == 3:
377
+ attention_mask = attention_mask.unsqueeze(1)
378
+
379
+ return attention_mask
380
+
381
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
382
+ if self.relative_attention and relative_pos is None:
383
+ q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
384
+ relative_pos = build_relative_position(q, hidden_states.size(-2), hidden_states.device)
385
+ return relative_pos
386
+
387
+ def forward(
388
+ self,
389
+ hidden_states,
390
+ attention_mask,
391
+ output_hidden_states=True,
392
+ output_attentions=False,
393
+ query_states=None,
394
+ relative_pos=None,
395
+ return_dict=True,
396
+ past_key_values=None,
397
+ ):
398
+ attention_mask = self.get_attention_mask(attention_mask)
399
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
400
+
401
+ all_hidden_states = () if output_hidden_states else None
402
+ all_attentions = () if output_attentions else None
403
+
404
+ if isinstance(hidden_states, Sequence):
405
+ next_kv = hidden_states[0]
406
+ else:
407
+ next_kv = hidden_states
408
+ rel_embeddings = self.get_rel_embedding()
409
+ for i, layer_module in enumerate(self.layer):
410
+
411
+ if output_hidden_states:
412
+ all_hidden_states = all_hidden_states + (hidden_states,)
413
+
414
+ past_key_value = past_key_values[i] if past_key_values is not None else None
415
+
416
+ hidden_states = layer_module(
417
+ next_kv,
418
+ attention_mask,
419
+ output_attentions,
420
+ query_states=query_states,
421
+ relative_pos=relative_pos,
422
+ rel_embeddings=rel_embeddings,
423
+ past_key_value=past_key_value,
424
+ )
425
+ if output_attentions:
426
+ hidden_states, att_m = hidden_states
427
+
428
+ if query_states is not None:
429
+ query_states = hidden_states
430
+ if isinstance(hidden_states, Sequence):
431
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
432
+ else:
433
+ next_kv = hidden_states
434
+
435
+ if output_attentions:
436
+ all_attentions = all_attentions + (att_m,)
437
+
438
+ if output_hidden_states:
439
+ all_hidden_states = all_hidden_states + (hidden_states,)
440
+
441
+ if not return_dict:
442
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
443
+ return BaseModelOutput(
444
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
445
+ )
446
+
447
+
448
+ def build_relative_position(query_size, key_size, device):
449
+ """
450
+ Build relative position according to the query and key
451
+
452
+    We assume the absolute position of the query :math:`P_q` ranges over [0, query_size) and the absolute position of the key
453
+    :math:`P_k` ranges over [0, key_size). The relative position from query to key is :math:`R_{q \\rightarrow k} =
454
+    P_q - P_k`
455
+
456
+ Args:
457
+ query_size (int): the length of query
458
+ key_size (int): the length of key
459
+
460
+ Return:
461
+ :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
462
+
463
+ """
464
+
465
+ q_ids = torch.arange(query_size, dtype=torch.long, device=device)
466
+ k_ids = torch.arange(key_size, dtype=torch.long, device=device)
467
+ rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
468
+ rel_pos_ids = rel_pos_ids[:query_size, :]
469
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
470
+ return rel_pos_ids
471
+
472
+
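For orientation, here is a minimal, self-contained sketch (not part of the commit; toy sizes are assumed) of what `build_relative_position(3, 3, device)` evaluates to — entry [i, j] is the signed offset i - j:

import torch

# Toy reproduction of the index arithmetic above (assumed query_size = key_size = 3).
q_ids = torch.arange(3)
k_ids = torch.arange(3)
rel = (q_ids[:, None] - k_ids[None, :]).unsqueeze(0)
print(rel)
# tensor([[[ 0, -1, -2],
#          [ 1,  0, -1],
#          [ 2,  1,  0]]])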
473
+ @torch.jit.script
474
+ def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
475
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
476
+
477
+
478
+ @torch.jit.script
479
+ def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
480
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
481
+
482
+
483
+ @torch.jit.script
484
+ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
485
+ return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
486
+
487
+
488
+ class DisentangledSelfAttention(nn.Module):
489
+ """
490
+ Disentangled self-attention module
491
+
492
+ Parameters:
493
+    config (:obj:`DebertaConfig`):
494
+    A model config class instance with the configuration to build a new model. The schema is similar to
495
+    `BertConfig`; for more details, please refer to :class:`~transformers.DebertaConfig`
496
+
497
+ """
498
+
499
+ def __init__(self, config):
500
+ super().__init__()
501
+ if config.hidden_size % config.num_attention_heads != 0:
502
+ raise ValueError(
503
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
504
+ f"heads ({config.num_attention_heads})"
505
+ )
506
+ self.num_attention_heads = config.num_attention_heads
507
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
508
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
509
+ self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
510
+ self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
511
+ self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
512
+ self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
513
+
514
+ self.relative_attention = getattr(config, "relative_attention", False)
515
+ self.talking_head = getattr(config, "talking_head", False)
516
+
517
+ if self.talking_head:
518
+ self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
519
+ self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
520
+
521
+ if self.relative_attention:
522
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
523
+ if self.max_relative_positions < 1:
524
+ self.max_relative_positions = config.max_position_embeddings
525
+ self.pos_dropout = StableDropout(config.hidden_dropout_prob)
526
+
527
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
528
+ self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
529
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
530
+ self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
531
+
532
+ self.dropout = StableDropout(config.attention_probs_dropout_prob)
533
+
534
+ def transpose_for_scores(self, x):
535
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
536
+ x = x.view(*new_x_shape)
537
+ return x.permute(0, 2, 1, 3)
538
+
539
+ def forward(
540
+ self,
541
+ hidden_states,
542
+ attention_mask,
543
+ return_att=False,
544
+ query_states=None,
545
+ relative_pos=None,
546
+ rel_embeddings=None,
547
+ past_key_value=None,
548
+ ):
549
+ """
550
+ Call the module
551
+
552
+ Args:
553
+ hidden_states (:obj:`torch.FloatTensor`):
554
+    Input states to the module, usually the output of the previous layer; they are used as the Q, K and V in
555
+ `Attention(Q,K,V)`
556
+
557
+ attention_mask (:obj:`torch.ByteTensor`):
558
+    An attention mask matrix of shape [`B`, `N`, `N`], where `B` is the batch size and `N` is the maximum
559
+    sequence length, in which element [i,j] = `1` means the `i`-th token in the input can attend to the `j`-th
560
+    token.
561
+
562
+ return_att (:obj:`bool`, optional):
563
+    Whether to return the attention matrix.
564
+
565
+ query_states (:obj:`torch.FloatTensor`, optional):
566
+ The `Q` state in `Attention(Q,K,V)`.
567
+
568
+ relative_pos (:obj:`torch.LongTensor`):
569
+ The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with
570
+ values ranging in [`-max_relative_positions`, `max_relative_positions`].
571
+
572
+ rel_embeddings (:obj:`torch.FloatTensor`):
573
+ The embedding of relative distances. It's a tensor of shape [:math:`2 \\times
574
+ \\text{max_relative_positions}`, `hidden_size`].
575
+
576
+
577
+ """
578
+ if query_states is None:
579
+ qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
580
+ query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
581
+ else:
582
+
583
+ def linear(w, b, x):
584
+ if b is not None:
585
+ return torch.matmul(x, w.t()) + b.t()
586
+ else:
587
+ return torch.matmul(x, w.t()) # + b.t()
588
+
589
+ ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0)
590
+ qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
591
+ qkvb = [None] * 3
592
+
593
+ q = linear(qkvw[0], qkvb[0], query_states)
594
+ k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
595
+ query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
596
+
597
+ query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
598
+ value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
599
+
600
+ rel_att = None
601
+ # Take the dot product between "query" and "key" to get the raw attention scores.
602
+ scale_factor = 1 + len(self.pos_att_type)
603
+ scale = math.sqrt(query_layer.size(-1) * scale_factor)
604
+
605
+ past_key_value_length = past_key_value.shape[3] if past_key_value is not None else 0
606
+ if past_key_value is not None:
607
+ key_layer_prefix = torch.cat([past_key_value[0], key_layer], dim=2)
608
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
609
+ else:
610
+ key_layer_prefix = key_layer
611
+
612
+ query_layer = query_layer / scale
613
+ attention_scores = torch.matmul(query_layer, key_layer_prefix.transpose(-1, -2))
614
+ if self.relative_attention:
615
+ rel_embeddings = self.pos_dropout(rel_embeddings)
616
+ rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
617
+
618
+ if rel_att is not None:
619
+ if past_key_value is not None:
620
+ # print(attention_scores.shape)
621
+ # print(rel_att.shape)
622
+ # exit()
623
+ att_shape = rel_att.shape[:-1] + (past_key_value_length,)
624
+ prefix_att = torch.zeros(*att_shape).to(rel_att.device)
625
+ attention_scores = attention_scores + torch.cat([prefix_att, rel_att], dim=-1)
626
+ else:
627
+ attention_scores = attention_scores + rel_att
628
+
629
+ # bxhxlxd
630
+ if self.talking_head:
631
+ attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
632
+
633
+    softmax_mask = attention_mask[:, :, past_key_value_length:, :]
634
+
635
+ attention_probs = XSoftmax.apply(attention_scores, softmax_mask, -1)
636
+ attention_probs = self.dropout(attention_probs)
637
+ if self.talking_head:
638
+ attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
639
+
640
+ context_layer = torch.matmul(attention_probs, value_layer)
641
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
642
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
643
+ context_layer = context_layer.view(*new_context_layer_shape)
644
+ if return_att:
645
+ return (context_layer, attention_probs)
646
+ else:
647
+ return context_layer
648
+
649
+ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
650
+ if relative_pos is None:
651
+ q = query_layer.size(-2)
652
+ relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device)
653
+ if relative_pos.dim() == 2:
654
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
655
+ elif relative_pos.dim() == 3:
656
+ relative_pos = relative_pos.unsqueeze(1)
657
+ # bxhxqxk
658
+ elif relative_pos.dim() != 4:
659
+    raise ValueError(f"Relative position ids must be of dim 2, 3 or 4, but got dim {relative_pos.dim()}")
660
+
661
+ att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions)
662
+ relative_pos = relative_pos.long().to(query_layer.device)
663
+ rel_embeddings = rel_embeddings[
664
+ self.max_relative_positions - att_span : self.max_relative_positions + att_span, :
665
+ ].unsqueeze(0)
666
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
667
+ pos_key_layer = self.pos_proj(rel_embeddings)
668
+ pos_key_layer = self.transpose_for_scores(pos_key_layer)
669
+
670
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
671
+ pos_query_layer = self.pos_q_proj(rel_embeddings)
672
+ pos_query_layer = self.transpose_for_scores(pos_query_layer)
673
+
674
+ score = 0
675
+ # content->position
676
+ if "c2p" in self.pos_att_type:
677
+ c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
678
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
679
+ c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
680
+ score += c2p_att
681
+
682
+ # position->content
683
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
684
+ pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
685
+ if query_layer.size(-2) != key_layer.size(-2):
686
+ r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
687
+ else:
688
+ r_pos = relative_pos
689
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
690
+ if query_layer.size(-2) != key_layer.size(-2):
691
+ pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
692
+
693
+ if "p2c" in self.pos_att_type:
694
+ p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
695
+ p2c_att = torch.gather(
696
+ p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
697
+ ).transpose(-1, -2)
698
+ if query_layer.size(-2) != key_layer.size(-2):
699
+ p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
700
+ score += p2c_att
701
+
702
+ return score
703
+
704
+
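As a hedged illustration of the index arithmetic inside `disentangled_att_bias` (toy values assumed, not part of the commit): the signed relative distance is shifted by `att_span` and clamped into the embedding-table range before being used as a gather index.

import torch

att_span = 4  # assumed toy span; in the model it is min(max(q_len, k_len), max_relative_positions)
relative_pos = torch.tensor([[-6, -1, 0, 2, 6]])
c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
print(c2p_pos)  # tensor([[0, 3, 4, 6, 7]]) -- distances beyond the span saturate at the edges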
705
+ class DebertaEmbeddings(nn.Module):
706
+ """Construct the embeddings from word, position and token_type embeddings."""
707
+
708
+ def __init__(self, config):
709
+ super().__init__()
710
+ pad_token_id = getattr(config, "pad_token_id", 0)
711
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
712
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
713
+
714
+ self.position_biased_input = getattr(config, "position_biased_input", True)
715
+ if not self.position_biased_input:
716
+ self.position_embeddings = None
717
+ else:
718
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
719
+
720
+ if config.type_vocab_size > 0:
721
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
722
+
723
+ if self.embedding_size != config.hidden_size:
724
+ self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
725
+ self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
726
+ self.dropout = StableDropout(config.hidden_dropout_prob)
727
+ self.config = config
728
+
729
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
730
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
731
+
732
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None, past_key_values_length=0):
733
+ if input_ids is not None:
734
+ input_shape = input_ids.size()
735
+ else:
736
+ input_shape = inputs_embeds.size()[:-1]
737
+
738
+ seq_length = input_shape[1]
739
+
740
+ if position_ids is None:
741
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
742
+
743
+ if token_type_ids is None:
744
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
745
+
746
+ if inputs_embeds is None:
747
+ inputs_embeds = self.word_embeddings(input_ids)
748
+
749
+ if self.position_embeddings is not None:
750
+ position_embeddings = self.position_embeddings(position_ids.long())
751
+ else:
752
+ position_embeddings = torch.zeros_like(inputs_embeds)
753
+
754
+ embeddings = inputs_embeds
755
+ if self.position_biased_input:
756
+ embeddings += position_embeddings
757
+ if self.config.type_vocab_size > 0:
758
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
759
+ embeddings += token_type_embeddings
760
+
761
+ if self.embedding_size != self.config.hidden_size:
762
+ embeddings = self.embed_proj(embeddings)
763
+
764
+ embeddings = self.LayerNorm(embeddings)
765
+
766
+ if mask is not None:
767
+ if mask.dim() != embeddings.dim():
768
+ if mask.dim() == 4:
769
+ mask = mask.squeeze(1).squeeze(1)
770
+ mask = mask.unsqueeze(2)
771
+ mask = mask.to(embeddings.dtype)
772
+
773
+ embeddings = embeddings * mask
774
+
775
+ embeddings = self.dropout(embeddings)
776
+ return embeddings
777
+
778
+
779
+ class DebertaPreTrainedModel(PreTrainedModel):
780
+ """
781
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
782
+ models.
783
+ """
784
+
785
+ config_class = DebertaConfig
786
+ base_model_prefix = "deberta"
787
+ _keys_to_ignore_on_load_missing = ["position_ids"]
788
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
789
+
790
+ def __init__(self, config):
791
+ super().__init__(config)
792
+ self._register_load_state_dict_pre_hook(self._pre_load_hook)
793
+
794
+ def _init_weights(self, module):
795
+ """Initialize the weights."""
796
+ if isinstance(module, nn.Linear):
797
+ # Slightly different from the TF version which uses truncated_normal for initialization
798
+ # cf https://github.com/pytorch/pytorch/pull/5617
799
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
800
+ if module.bias is not None:
801
+ module.bias.data.zero_()
802
+ elif isinstance(module, nn.Embedding):
803
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
804
+ if module.padding_idx is not None:
805
+ module.weight.data[module.padding_idx].zero_()
806
+
807
+ def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
808
+ """
809
+ Removes the classifier if it doesn't have the correct number of labels.
810
+ """
811
+ self_state = self.state_dict()
812
+ if (
813
+ ("classifier.weight" in self_state)
814
+ and ("classifier.weight" in state_dict)
815
+ and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
816
+ ):
817
+ logger.warning(
818
+ f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
819
+ f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
820
+ f"weights. You should train your model on new data."
821
+ )
822
+ del state_dict["classifier.weight"]
823
+ if "classifier.bias" in state_dict:
824
+ del state_dict["classifier.bias"]
825
+
826
+
827
+ DEBERTA_START_DOCSTRING = r"""
828
+ The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
829
+ <https://arxiv.org/abs/2006.03654>`_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built on top of
830
+ BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
831
+ improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB of pretraining data.
832
+
833
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
834
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
835
+ general usage and behavior.
836
+
837
+
838
+ Parameters:
839
+ config (:class:`~transformers.DebertaConfig`): Model configuration class with all the parameters of the model.
840
+ Initializing with a config file does not load the weights associated with the model, only the
841
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
842
+ weights.
843
+ """
844
+
845
+ DEBERTA_INPUTS_DOCSTRING = r"""
846
+ Args:
847
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
848
+ Indices of input sequence tokens in the vocabulary.
849
+
850
+ Indices can be obtained using :class:`transformers.DebertaTokenizer`. See
851
+ :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
852
+ details.
853
+
854
+ `What are input IDs? <../glossary.html#input-ids>`__
855
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`):
856
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
857
+
858
+ - 1 for tokens that are **not masked**,
859
+ - 0 for tokens that are **masked**.
860
+
861
+ `What are attention masks? <../glossary.html#attention-mask>`__
862
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
863
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
864
+ 1]``:
865
+
866
+ - 0 corresponds to a `sentence A` token,
867
+ - 1 corresponds to a `sentence B` token.
868
+
869
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
870
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
871
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
872
+ config.max_position_embeddings - 1]``.
873
+
874
+ `What are position IDs? <../glossary.html#position-ids>`_
875
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
876
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
877
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
878
+ than the model's internal embedding lookup matrix.
879
+ output_attentions (:obj:`bool`, `optional`):
880
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
881
+ tensors for more detail.
882
+ output_hidden_states (:obj:`bool`, `optional`):
883
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
884
+ more detail.
885
+ return_dict (:obj:`bool`, `optional`):
886
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
887
+ """
888
+
889
+
890
+ @add_start_docstrings(
891
+ "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
892
+ DEBERTA_START_DOCSTRING,
893
+ )
894
+ class DebertaModel(DebertaPreTrainedModel):
895
+ def __init__(self, config):
896
+ super().__init__(config)
897
+
898
+ self.embeddings = DebertaEmbeddings(config)
899
+ self.encoder = DebertaEncoder(config)
900
+ self.z_steps = 0
901
+ self.config = config
902
+ self.init_weights()
903
+
904
+ def get_input_embeddings(self):
905
+ return self.embeddings.word_embeddings
906
+
907
+ def set_input_embeddings(self, new_embeddings):
908
+ self.embeddings.word_embeddings = new_embeddings
909
+
910
+ def _prune_heads(self, heads_to_prune):
911
+ """
912
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
913
+ class PreTrainedModel
914
+ """
915
+ raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
916
+
917
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
918
+ def forward(
919
+ self,
920
+ input_ids=None,
921
+ attention_mask=None,
922
+ token_type_ids=None,
923
+ position_ids=None,
924
+ inputs_embeds=None,
925
+ output_attentions=None,
926
+ output_hidden_states=None,
927
+ return_dict=None,
928
+ past_key_values=None,
929
+ ):
930
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
931
+ output_hidden_states = (
932
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
933
+ )
934
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
935
+
936
+ if input_ids is not None and inputs_embeds is not None:
937
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
938
+ elif input_ids is not None:
939
+ input_shape = input_ids.size()
940
+ batch_size, seq_length = input_shape
941
+ elif inputs_embeds is not None:
942
+ input_shape = inputs_embeds.size()[:-1]
943
+ batch_size, seq_length = input_shape
944
+ else:
945
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
946
+
947
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
948
+
949
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
950
+
951
+ if attention_mask is None:
952
+ # attention_mask = torch.ones(input_shape, device=device)
953
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
954
+ embedding_mask = attention_mask[:, past_key_values_length:].contiguous()
955
+ if token_type_ids is None:
956
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
957
+
958
+ embedding_output = self.embeddings(
959
+ input_ids=input_ids,
960
+ token_type_ids=token_type_ids,
961
+ position_ids=position_ids,
962
+ mask=embedding_mask,
963
+ inputs_embeds=inputs_embeds,
964
+ past_key_values_length=past_key_values_length,
965
+ )
966
+
967
+ encoder_outputs = self.encoder(
968
+ embedding_output,
969
+ attention_mask,
970
+ output_hidden_states=True,
971
+ output_attentions=output_attentions,
972
+ return_dict=return_dict,
973
+ past_key_values=past_key_values,
974
+ )
975
+ encoded_layers = encoder_outputs[1]
976
+
977
+ if self.z_steps > 1:
978
+ hidden_states = encoded_layers[-2]
979
+ layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
980
+ query_states = encoded_layers[-1]
981
+ rel_embeddings = self.encoder.get_rel_embedding()
982
+ attention_mask = self.encoder.get_attention_mask(attention_mask)
983
+ rel_pos = self.encoder.get_rel_pos(embedding_output)
984
+ for layer in layers[1:]:
985
+ query_states = layer(
986
+ hidden_states,
987
+ attention_mask,
988
+ return_att=False,
989
+ query_states=query_states,
990
+ relative_pos=rel_pos,
991
+ rel_embeddings=rel_embeddings,
992
+ )
993
+ encoded_layers.append(query_states)
994
+
995
+ sequence_output = encoded_layers[-1]
996
+
997
+ if not return_dict:
998
+ return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
999
+
1000
+ return BaseModelOutput(
1001
+ last_hidden_state=sequence_output,
1002
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
1003
+ attentions=encoder_outputs.attentions,
1004
+ )
1005
+
1006
+
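A minimal usage sketch for the `DebertaModel` defined above (hedged: the tiny config values are assumptions chosen so the example runs quickly; real use would load a checkpoint with `from_pretrained`):

import torch
from transformers import DebertaConfig

config = DebertaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                       num_attention_heads=4, intermediate_size=64,
                       max_position_embeddings=64)
model = DebertaModel(config)  # the class defined in this file, with past_key_values support
input_ids = torch.randint(0, 100, (1, 8))
attention_mask = torch.ones(1, 8, dtype=torch.long)
out = model(input_ids, attention_mask=attention_mask)
print(out.last_hidden_state.shape)  # torch.Size([1, 8, 32])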
1007
+ @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING)
1008
+ class DebertaForMaskedLM(DebertaPreTrainedModel):
1009
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1010
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1011
+
1012
+ def __init__(self, config):
1013
+ super().__init__(config)
1014
+
1015
+ self.deberta = DebertaModel(config)
1016
+ self.cls = DebertaOnlyMLMHead(config)
1017
+
1018
+ self.init_weights()
1019
+
1020
+ def get_output_embeddings(self):
1021
+ return self.cls.predictions.decoder
1022
+
1023
+ def set_output_embeddings(self, new_embeddings):
1024
+ self.cls.predictions.decoder = new_embeddings
1025
+
1026
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1027
+ def forward(
1028
+ self,
1029
+ input_ids=None,
1030
+ attention_mask=None,
1031
+ token_type_ids=None,
1032
+ position_ids=None,
1033
+ inputs_embeds=None,
1034
+ labels=None,
1035
+ output_attentions=None,
1036
+ output_hidden_states=None,
1037
+ return_dict=None,
1038
+ ):
1039
+ r"""
1040
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1041
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1042
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1043
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1044
+ """
1045
+
1046
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1047
+
1048
+ outputs = self.deberta(
1049
+ input_ids,
1050
+ attention_mask=attention_mask,
1051
+ token_type_ids=token_type_ids,
1052
+ position_ids=position_ids,
1053
+ inputs_embeds=inputs_embeds,
1054
+ output_attentions=output_attentions,
1055
+ output_hidden_states=output_hidden_states,
1056
+ return_dict=return_dict,
1057
+ )
1058
+
1059
+ sequence_output = outputs[0]
1060
+ prediction_scores = self.cls(sequence_output)
1061
+
1062
+ masked_lm_loss = None
1063
+ if labels is not None:
1064
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1065
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1066
+
1067
+ if not return_dict:
1068
+ output = (prediction_scores,) + outputs[1:]
1069
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1070
+
1071
+ return MaskedLMOutput(
1072
+ loss=masked_lm_loss,
1073
+ logits=prediction_scores,
1074
+ hidden_states=outputs.hidden_states,
1075
+ attentions=outputs.attentions,
1076
+ )
1077
+
1078
+
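The masked-LM loss above relies on `CrossEntropyLoss`'s default `ignore_index=-100`, so only masked positions contribute. A hedged toy computation (shapes and values assumed):

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
prediction_scores = torch.randn(1, 4, vocab_size)
labels = torch.tensor([[-100, 3, -100, 7]])  # -100 marks positions excluded from the loss
masked_lm_loss = CrossEntropyLoss()(prediction_scores.view(-1, vocab_size), labels.view(-1))
print(masked_lm_loss)  # averaged over the two labeled tokens only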
1079
+ # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
1080
+ class DebertaPredictionHeadTransform(nn.Module):
1081
+ def __init__(self, config):
1082
+ super().__init__()
1083
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1084
+ if isinstance(config.hidden_act, str):
1085
+ self.transform_act_fn = ACT2FN[config.hidden_act]
1086
+ else:
1087
+ self.transform_act_fn = config.hidden_act
1088
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1089
+
1090
+ def forward(self, hidden_states):
1091
+ hidden_states = self.dense(hidden_states)
1092
+ hidden_states = self.transform_act_fn(hidden_states)
1093
+ hidden_states = self.LayerNorm(hidden_states)
1094
+ return hidden_states
1095
+
1096
+
1097
+ # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
1098
+ class DebertaLMPredictionHead(nn.Module):
1099
+ def __init__(self, config):
1100
+ super().__init__()
1101
+ self.transform = DebertaPredictionHeadTransform(config)
1102
+
1103
+ # The output weights are the same as the input embeddings, but there is
1104
+ # an output-only bias for each token.
1105
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1106
+
1107
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1108
+
1109
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
1110
+ self.decoder.bias = self.bias
1111
+
1112
+ def forward(self, hidden_states):
1113
+ hidden_states = self.transform(hidden_states)
1114
+ hidden_states = self.decoder(hidden_states)
1115
+ return hidden_states
1116
+
1117
+
1118
+ # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
1119
+ class DebertaOnlyMLMHead(nn.Module):
1120
+ def __init__(self, config):
1121
+ super().__init__()
1122
+ self.predictions = DebertaLMPredictionHead(config)
1123
+
1124
+ def forward(self, sequence_output):
1125
+ prediction_scores = self.predictions(sequence_output)
1126
+ return prediction_scores
1127
+
1128
+
1129
+ @add_start_docstrings(
1130
+ """
1131
+ DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1132
+ pooled output) e.g. for GLUE tasks.
1133
+ """,
1134
+ DEBERTA_START_DOCSTRING,
1135
+ )
1136
+ class DebertaForSequenceClassification(DebertaPreTrainedModel):
1137
+ def __init__(self, config):
1138
+ super().__init__(config)
1139
+
1140
+ num_labels = getattr(config, "num_labels", 2)
1141
+ self.num_labels = num_labels
1142
+
1143
+ self.deberta = DebertaModel(config)
1144
+ self.pooler = ContextPooler(config)
1145
+ output_dim = self.pooler.output_dim
1146
+
1147
+ self.classifier = nn.Linear(output_dim, num_labels)
1148
+ drop_out = getattr(config, "cls_dropout", None)
1149
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1150
+ self.dropout = StableDropout(drop_out)
1151
+
1152
+ self.init_weights()
1153
+
1154
+ def get_input_embeddings(self):
1155
+ return self.deberta.get_input_embeddings()
1156
+
1157
+ def set_input_embeddings(self, new_embeddings):
1158
+ self.deberta.set_input_embeddings(new_embeddings)
1159
+
1160
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1161
+ def forward(
1162
+ self,
1163
+ input_ids=None,
1164
+ attention_mask=None,
1165
+ token_type_ids=None,
1166
+ position_ids=None,
1167
+ inputs_embeds=None,
1168
+ labels=None,
1169
+ output_attentions=None,
1170
+ output_hidden_states=None,
1171
+ return_dict=None,
1172
+ ):
1173
+ r"""
1174
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1175
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1176
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1177
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1178
+ """
1179
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1180
+
1181
+ outputs = self.deberta(
1182
+ input_ids,
1183
+ token_type_ids=token_type_ids,
1184
+ attention_mask=attention_mask,
1185
+ position_ids=position_ids,
1186
+ inputs_embeds=inputs_embeds,
1187
+ output_attentions=output_attentions,
1188
+ output_hidden_states=output_hidden_states,
1189
+ return_dict=return_dict,
1190
+ )
1191
+
1192
+ encoder_layer = outputs[0]
1193
+ pooled_output = self.pooler(encoder_layer)
1194
+ pooled_output = self.dropout(pooled_output)
1195
+ logits = self.classifier(pooled_output)
1196
+
1197
+ loss = None
1198
+ if labels is not None:
1199
+ if self.num_labels == 1:
1200
+ # regression task
1201
+ loss_fn = nn.MSELoss()
1202
+ logits = logits.view(-1).to(labels.dtype)
1203
+ loss = loss_fn(logits, labels.view(-1))
1204
+ elif labels.dim() == 1 or labels.size(-1) == 1:
1205
+ label_index = (labels >= 0).nonzero()
1206
+ labels = labels.long()
1207
+ if label_index.size(0) > 0:
1208
+ labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
1209
+ labels = torch.gather(labels, 0, label_index.view(-1))
1210
+ loss_fct = CrossEntropyLoss()
1211
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
1212
+ else:
1213
+ loss = torch.tensor(0).to(logits)
1214
+ else:
1215
+ log_softmax = nn.LogSoftmax(-1)
1216
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
1217
+ if not return_dict:
1218
+ output = (logits,) + outputs[1:]
1219
+ return ((loss,) + output) if loss is not None else output
1220
+ else:
1221
+ return SequenceClassifierOutput(
1222
+ loss=loss,
1223
+ logits=logits,
1224
+ hidden_states=outputs.hidden_states,
1225
+ attentions=outputs.attentions,
1226
+ )
1227
+
1228
+
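The classification head above handles three label formats: a single logit with MSE for regression, integer class indices with cross-entropy, and soft label distributions via the log-softmax branch. A hedged numeric sketch of the last case (toy values assumed):

import torch
from torch import nn

logits = torch.tensor([[2.0, 0.5], [0.1, 1.2]])
soft_labels = torch.tensor([[0.9, 0.1], [0.2, 0.8]])  # each row sums to 1
log_softmax = nn.LogSoftmax(-1)
loss = -((log_softmax(logits) * soft_labels).sum(-1)).mean()
print(loss)  # cross-entropy against the soft targets, as in the else-branch above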
1229
+ @add_start_docstrings(
1230
+ """
1231
+ DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1232
+ Named-Entity-Recognition (NER) tasks.
1233
+ """,
1234
+ DEBERTA_START_DOCSTRING,
1235
+ )
1236
+ class DebertaForTokenClassification(DebertaPreTrainedModel):
1237
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1238
+
1239
+ def __init__(self, config):
1240
+ super().__init__(config)
1241
+ self.num_labels = config.num_labels
1242
+
1243
+ self.deberta = DebertaModel(config)
1244
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1245
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1246
+
1247
+ for param in self.deberta.parameters():
1248
+ param.requires_grad = False
1249
+
1250
+ self.init_weights()
1251
+
1252
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1253
+ def forward(
1254
+ self,
1255
+ input_ids=None,
1256
+ attention_mask=None,
1257
+ token_type_ids=None,
1258
+ position_ids=None,
1259
+ inputs_embeds=None,
1260
+ labels=None,
1261
+ output_attentions=None,
1262
+ output_hidden_states=None,
1263
+ return_dict=None,
1264
+ ):
1265
+ r"""
1266
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1267
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1268
+ 1]``.
1269
+ """
1270
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1271
+
1272
+ outputs = self.deberta(
1273
+ input_ids,
1274
+ attention_mask=attention_mask,
1275
+ token_type_ids=token_type_ids,
1276
+ position_ids=position_ids,
1277
+ inputs_embeds=inputs_embeds,
1278
+ output_attentions=output_attentions,
1279
+ output_hidden_states=output_hidden_states,
1280
+ return_dict=return_dict,
1281
+ )
1282
+
1283
+ sequence_output = outputs[0]
1284
+
1285
+ sequence_output = self.dropout(sequence_output)
1286
+ logits = self.classifier(sequence_output)
1287
+
1288
+ loss = None
1289
+ if labels is not None:
1290
+ loss_fct = CrossEntropyLoss()
1291
+ # Only keep active parts of the loss
1292
+ if attention_mask is not None:
1293
+ active_loss = attention_mask.view(-1) == 1
1294
+ active_logits = logits.view(-1, self.num_labels)
1295
+ active_labels = torch.where(
1296
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1297
+ )
1298
+ loss = loss_fct(active_logits, active_labels)
1299
+ else:
1300
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1301
+
1302
+ if not return_dict:
1303
+ output = (logits,) + outputs[1:]
1304
+ return ((loss,) + output) if loss is not None else output
1305
+
1306
+ return TokenClassifierOutput(
1307
+ loss=loss,
1308
+ logits=logits,
1309
+ hidden_states=outputs.hidden_states,
1310
+ attentions=outputs.attentions,
1311
+ )
1312
+
1313
+
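A hedged sketch of the masking trick used above for token classification (toy tensors assumed): labels at padded positions are swapped for the loss function's `ignore_index` so they do not contribute to the loss.

import torch
from torch.nn import CrossEntropyLoss

loss_fct = CrossEntropyLoss()
labels = torch.tensor([[1, 2, 0, 0]])
attention_mask = torch.tensor([[1, 1, 0, 0]])  # last two tokens are padding
active_loss = attention_mask.view(-1) == 1
active_labels = torch.where(active_loss, labels.view(-1),
                            torch.tensor(loss_fct.ignore_index).type_as(labels))
print(active_labels)  # tensor([   1,    2, -100, -100])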
1314
+ @add_start_docstrings(
1315
+ """
1316
+ DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1317
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1318
+ """,
1319
+ DEBERTA_START_DOCSTRING,
1320
+ )
1321
+ class DebertaForQuestionAnswering(DebertaPreTrainedModel):
1322
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1323
+
1324
+ def __init__(self, config):
1325
+ super().__init__(config)
1326
+ self.num_labels = config.num_labels
1327
+
1328
+ self.deberta = DebertaModel(config)
1329
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1330
+
1331
+ self.init_weights()
1332
+
1333
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1334
+ def forward(
1335
+ self,
1336
+ input_ids=None,
1337
+ attention_mask=None,
1338
+ token_type_ids=None,
1339
+ position_ids=None,
1340
+ inputs_embeds=None,
1341
+ start_positions=None,
1342
+ end_positions=None,
1343
+ output_attentions=None,
1344
+ output_hidden_states=None,
1345
+ return_dict=None,
1346
+ ):
1347
+ r"""
1348
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1349
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1350
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1351
+ sequence are not taken into account for computing the loss.
1352
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1353
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1354
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1355
+ sequence are not taken into account for computing the loss.
1356
+ """
1357
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1358
+
1359
+ outputs = self.deberta(
1360
+ input_ids,
1361
+ attention_mask=attention_mask,
1362
+ token_type_ids=token_type_ids,
1363
+ position_ids=position_ids,
1364
+ inputs_embeds=inputs_embeds,
1365
+ output_attentions=output_attentions,
1366
+ output_hidden_states=output_hidden_states,
1367
+ return_dict=return_dict,
1368
+ )
1369
+
1370
+ sequence_output = outputs[0]
1371
+
1372
+ logits = self.qa_outputs(sequence_output)
1373
+ start_logits, end_logits = logits.split(1, dim=-1)
1374
+ start_logits = start_logits.squeeze(-1).contiguous()
1375
+ end_logits = end_logits.squeeze(-1).contiguous()
1376
+
1377
+ total_loss = None
1378
+ if start_positions is not None and end_positions is not None:
1379
+ # If we are on multi-GPU, split add a dimension
1380
+ if len(start_positions.size()) > 1:
1381
+ start_positions = start_positions.squeeze(-1)
1382
+ if len(end_positions.size()) > 1:
1383
+ end_positions = end_positions.squeeze(-1)
1384
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1385
+ ignored_index = start_logits.size(1)
1386
+ start_positions = start_positions.clamp(0, ignored_index)
1387
+ end_positions = end_positions.clamp(0, ignored_index)
1388
+
1389
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1390
+ start_loss = loss_fct(start_logits, start_positions)
1391
+ end_loss = loss_fct(end_logits, end_positions)
1392
+ total_loss = (start_loss + end_loss) / 2
1393
+
1394
+ if not return_dict:
1395
+ output = (start_logits, end_logits) + outputs[1:]
1396
+ return ((total_loss,) + output) if total_loss is not None else output
1397
+
1398
+ return QuestionAnsweringModelOutput(
1399
+ loss=total_loss,
1400
+ start_logits=start_logits,
1401
+ end_logits=end_logits,
1402
+ hidden_states=outputs.hidden_states,
1403
+ attentions=outputs.attentions,
1404
+ )
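For reference, the span loss in `DebertaForQuestionAnswering` averages two cross-entropies, one over start positions and one over end positions, with out-of-range positions clamped to an ignored index. A hedged toy computation (values assumed):

import torch
from torch.nn import CrossEntropyLoss

start_logits = torch.tensor([[0.2, 1.5, -0.3, 0.0]])
end_logits = torch.tensor([[0.1, 0.2, 2.0, -1.0]])
start_positions, end_positions = torch.tensor([1]), torch.tensor([2])
ignored_index = start_logits.size(1)
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
total_loss = (loss_fct(start_logits, start_positions) + loss_fct(end_logits, end_positions)) / 2
print(total_loss)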
soft_prompt/model/debertaV2.py ADDED
@@ -0,0 +1,1509 @@
1
+ # coding=utf-8
2
+ # Copyright 2020 Microsoft and the Hugging Face Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch DeBERTa-v2 model. """
16
+
17
+ import math
18
+ from collections.abc import Sequence
19
+
20
+ import numpy as np
21
+ import torch
22
+ from torch import _softmax_backward_data, nn
23
+ from torch.nn import CrossEntropyLoss, LayerNorm
24
+
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutput,
30
+ MaskedLMOutput,
31
+ QuestionAnsweringModelOutput,
32
+ SequenceClassifierOutput,
33
+ TokenClassifierOutput,
34
+ )
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import logging
37
+ from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config
38
+
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
+ _CONFIG_FOR_DOC = "DebertaV2Config"
43
+ _TOKENIZER_FOR_DOC = "DebertaV2Tokenizer"
44
+ _CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
45
+
46
+ DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
47
+ "microsoft/deberta-v2-xlarge",
48
+ "microsoft/deberta-v2-xxlarge",
49
+ "microsoft/deberta-v2-xlarge-mnli",
50
+ "microsoft/deberta-v2-xxlarge-mnli",
51
+ ]
52
+
53
+
54
+ # Copied from transformers.models.deberta.modeling_deberta.ContextPooler
55
+ class ContextPooler(nn.Module):
56
+ def __init__(self, config):
57
+ super().__init__()
58
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
59
+ self.dropout = StableDropout(config.pooler_dropout)
60
+ self.config = config
61
+
62
+ def forward(self, hidden_states):
63
+ # We "pool" the model by simply taking the hidden state corresponding
64
+ # to the first token.
65
+
66
+ context_token = hidden_states[:, 0]
67
+ context_token = self.dropout(context_token)
68
+ pooled_output = self.dense(context_token)
69
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
70
+ return pooled_output
71
+
72
+ @property
73
+ def output_dim(self):
74
+ return self.config.hidden_size
75
+
76
+
77
+ # Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
78
+ class XSoftmax(torch.autograd.Function):
79
+ """
80
+ Masked Softmax which is optimized for saving memory
81
+ Args:
82
+ input (:obj:`torch.tensor`): The input tensor that will apply softmax.
83
+    mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
84
+ dim (int): The dimension that will apply softmax
85
+ Example::
86
+ >>> import torch
87
+ >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax
88
+ >>> # Make a tensor
89
+ >>> x = torch.randn([4,20,100])
90
+ >>> # Create a mask
91
+ >>> mask = (x>0).int()
92
+ >>> y = XSoftmax.apply(x, mask, dim=-1)
93
+ """
94
+
95
+ @staticmethod
96
+ def forward(self, input, mask, dim):
97
+ self.dim = dim
98
+ rmask = ~(mask.bool())
99
+
100
+ output = input.masked_fill(rmask, float("-inf"))
101
+ output = torch.softmax(output, self.dim)
102
+ output.masked_fill_(rmask, 0)
103
+ self.save_for_backward(output)
104
+ return output
105
+
106
+ @staticmethod
107
+ def backward(self, grad_output):
108
+ (output,) = self.saved_tensors
109
+ inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
110
+ return inputGrad, None, None
111
+
112
+
113
+ # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
114
+ class DropoutContext(object):
115
+ def __init__(self):
116
+ self.dropout = 0
117
+ self.mask = None
118
+ self.scale = 1
119
+ self.reuse_mask = True
120
+
121
+
122
+ # Copied from transformers.models.deberta.modeling_deberta.get_mask
123
+ def get_mask(input, local_context):
124
+ if not isinstance(local_context, DropoutContext):
125
+ dropout = local_context
126
+ mask = None
127
+ else:
128
+ dropout = local_context.dropout
129
+ dropout *= local_context.scale
130
+ mask = local_context.mask if local_context.reuse_mask else None
131
+
132
+ if dropout > 0 and mask is None:
133
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
134
+
135
+ if isinstance(local_context, DropoutContext):
136
+ if local_context.mask is None:
137
+ local_context.mask = mask
138
+
139
+ return mask, dropout
140
+
141
+
142
+ # Copied from transformers.models.deberta.modeling_deberta.XDropout
143
+ class XDropout(torch.autograd.Function):
144
+ """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
145
+
146
+ @staticmethod
147
+ def forward(ctx, input, local_ctx):
148
+ mask, dropout = get_mask(input, local_ctx)
149
+ ctx.scale = 1.0 / (1 - dropout)
150
+ if dropout > 0:
151
+ ctx.save_for_backward(mask)
152
+ return input.masked_fill(mask, 0) * ctx.scale
153
+ else:
154
+ return input
155
+
156
+ @staticmethod
157
+ def backward(ctx, grad_output):
158
+ if ctx.scale > 1:
159
+ (mask,) = ctx.saved_tensors
160
+ return grad_output.masked_fill(mask, 0) * ctx.scale, None
161
+ else:
162
+ return grad_output, None
163
+
164
+
165
+ # Copied from transformers.models.deberta.modeling_deberta.StableDropout
166
+ class StableDropout(nn.Module):
167
+ """
168
+ Optimized dropout module for stabilizing the training
169
+ Args:
170
+    drop_prob (float): the dropout probability
171
+ """
172
+
173
+ def __init__(self, drop_prob):
174
+ super().__init__()
175
+ self.drop_prob = drop_prob
176
+ self.count = 0
177
+ self.context_stack = None
178
+
179
+ def forward(self, x):
180
+ """
181
+ Call the module
182
+ Args:
183
+ x (:obj:`torch.tensor`): The input tensor to apply dropout
184
+ """
185
+ if self.training and self.drop_prob > 0:
186
+ return XDropout.apply(x, self.get_context())
187
+ return x
188
+
189
+ def clear_context(self):
190
+ self.count = 0
191
+ self.context_stack = None
192
+
193
+ def init_context(self, reuse_mask=True, scale=1):
194
+ if self.context_stack is None:
195
+ self.context_stack = []
196
+ self.count = 0
197
+ for c in self.context_stack:
198
+ c.reuse_mask = reuse_mask
199
+ c.scale = scale
200
+
201
+ def get_context(self):
202
+ if self.context_stack is not None:
203
+ if self.count >= len(self.context_stack):
204
+ self.context_stack.append(DropoutContext())
205
+ ctx = self.context_stack[self.count]
206
+ ctx.dropout = self.drop_prob
207
+ self.count += 1
208
+ return ctx
209
+ else:
210
+ return self.drop_prob
211
+
212
+
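A hedged sketch of the inverted-dropout scaling that `XDropout`/`StableDropout` implement (toy tensor assumed): dropped units are zeroed and the survivors are rescaled by 1/(1 - p) so the expected value is preserved, matching `ctx.scale` in `XDropout.forward`.

import torch

p = 0.25
x = torch.ones(8)
mask = torch.rand_like(x) < p            # True where a unit is dropped
y = x.masked_fill(mask, 0) * (1.0 / (1 - p))
print(y)  # surviving entries equal 1/(1-p) ~= 1.3333, dropped entries are 0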
213
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
214
+ class DebertaV2SelfOutput(nn.Module):
215
+ def __init__(self, config):
216
+ super().__init__()
217
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
218
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
219
+ self.dropout = StableDropout(config.hidden_dropout_prob)
220
+
221
+ def forward(self, hidden_states, input_tensor):
222
+ hidden_states = self.dense(hidden_states)
223
+ hidden_states = self.dropout(hidden_states)
224
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
225
+ return hidden_states
226
+
227
+
228
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
229
+ class DebertaV2Attention(nn.Module):
230
+ def __init__(self, config):
231
+ super().__init__()
232
+ self.self = DisentangledSelfAttention(config)
233
+ self.output = DebertaV2SelfOutput(config)
234
+ self.config = config
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ attention_mask,
240
+ return_att=False,
241
+ query_states=None,
242
+ relative_pos=None,
243
+ rel_embeddings=None,
244
+ past_key_value=None,
245
+ ):
246
+ self_output = self.self(
247
+ hidden_states,
248
+ attention_mask,
249
+ return_att,
250
+ query_states=query_states,
251
+ relative_pos=relative_pos,
252
+ rel_embeddings=rel_embeddings,
253
+ past_key_value=past_key_value,
254
+ )
255
+ if return_att:
256
+ self_output, att_matrix = self_output
257
+ if query_states is None:
258
+ query_states = hidden_states
259
+ attention_output = self.output(self_output, query_states)
260
+
261
+ if return_att:
262
+ return (attention_output, att_matrix)
263
+ else:
264
+ return attention_output
265
+
266
+
267
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
268
+ class DebertaV2Intermediate(nn.Module):
269
+ def __init__(self, config):
270
+ super().__init__()
271
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
272
+ if isinstance(config.hidden_act, str):
273
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
274
+ else:
275
+ self.intermediate_act_fn = config.hidden_act
276
+
277
+ def forward(self, hidden_states):
278
+ hidden_states = self.dense(hidden_states)
279
+ hidden_states = self.intermediate_act_fn(hidden_states)
280
+ return hidden_states
281
+
282
+
283
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
284
+ class DebertaV2Output(nn.Module):
285
+ def __init__(self, config):
286
+ super().__init__()
287
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
288
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
289
+ self.dropout = StableDropout(config.hidden_dropout_prob)
290
+ self.config = config
291
+
292
+ def forward(self, hidden_states, input_tensor):
293
+ hidden_states = self.dense(hidden_states)
294
+ hidden_states = self.dropout(hidden_states)
295
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
296
+ return hidden_states
297
+
298
+
299
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
300
+ class DebertaV2Layer(nn.Module):
301
+ def __init__(self, config):
302
+ super().__init__()
303
+ self.attention = DebertaV2Attention(config)
304
+ self.intermediate = DebertaV2Intermediate(config)
305
+ self.output = DebertaV2Output(config)
306
+
307
+ def forward(
308
+ self,
309
+ hidden_states,
310
+ attention_mask,
311
+ return_att=False,
312
+ query_states=None,
313
+ relative_pos=None,
314
+ rel_embeddings=None,
315
+ past_key_value=None,
316
+ ):
317
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
318
+ attention_output = self.attention(
319
+ hidden_states,
320
+ attention_mask,
321
+ return_att=return_att,
322
+ query_states=query_states,
323
+ relative_pos=relative_pos,
324
+ rel_embeddings=rel_embeddings,
325
+ past_key_value=self_attn_past_key_value,
326
+ )
327
+ if return_att:
328
+ attention_output, att_matrix = attention_output
329
+ intermediate_output = self.intermediate(attention_output)
330
+ layer_output = self.output(intermediate_output, attention_output)
331
+ if return_att:
332
+ return (layer_output, att_matrix)
333
+ else:
334
+ return layer_output
335
+
336
+
337
+ class ConvLayer(nn.Module):
338
+ def __init__(self, config):
339
+ super().__init__()
340
+ kernel_size = getattr(config, "conv_kernel_size", 3)
341
+ groups = getattr(config, "conv_groups", 1)
342
+ self.conv_act = getattr(config, "conv_act", "tanh")
343
+ self.conv = nn.Conv1d(
344
+ config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
345
+ )
346
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
347
+ self.dropout = StableDropout(config.hidden_dropout_prob)
348
+ self.config = config
349
+
350
+ def forward(self, hidden_states, residual_states, input_mask):
351
+ out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
352
+ rmask = (1 - input_mask).bool()
353
+ out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
354
+ out = ACT2FN[self.conv_act](self.dropout(out))
355
+
356
+ layer_norm_input = residual_states + out
357
+ output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
358
+
359
+ if input_mask is None:
360
+ output_states = output
361
+ else:
362
+ if input_mask.dim() != layer_norm_input.dim():
363
+ if input_mask.dim() == 4:
364
+ input_mask = input_mask.squeeze(1).squeeze(1)
365
+ input_mask = input_mask.unsqueeze(2)
366
+
367
+ input_mask = input_mask.to(output.dtype)
368
+ output_states = output * input_mask
369
+
370
+ return output_states
371
+
372
+
373
+ class DebertaV2Encoder(nn.Module):
374
+ """Modified BertEncoder with relative position bias support"""
375
+
376
+ def __init__(self, config):
377
+ super().__init__()
378
+
379
+ self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
380
+ self.relative_attention = getattr(config, "relative_attention", False)
381
+
382
+ if self.relative_attention:
383
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
384
+ if self.max_relative_positions < 1:
385
+ self.max_relative_positions = config.max_position_embeddings
386
+
387
+ self.position_buckets = getattr(config, "position_buckets", -1)
388
+ pos_ebd_size = self.max_relative_positions * 2
389
+
390
+ if self.position_buckets > 0:
391
+ pos_ebd_size = self.position_buckets * 2
392
+
393
+ self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
394
+
395
+ self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
396
+
397
+ if "layer_norm" in self.norm_rel_ebd:
398
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
399
+
400
+ self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
401
+
402
+ def get_rel_embedding(self):
403
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
404
+ if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
405
+ rel_embeddings = self.LayerNorm(rel_embeddings)
406
+ return rel_embeddings
407
+
408
+ def get_attention_mask(self, attention_mask):
409
+ if attention_mask.dim() <= 2:
410
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
411
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
412
+ attention_mask = attention_mask.byte()
413
+ elif attention_mask.dim() == 3:
414
+ attention_mask = attention_mask.unsqueeze(1)
415
+
416
+ return attention_mask
417
+
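Aside (illustrative sketch, not part of the committed file): how get_attention_mask expands a 2-D padding mask into the pairwise [B, 1, N, N] mask consumed by the attention layers; the example assumes one sequence of length 4 with the last position padded.

import torch

mask = torch.tensor([[1, 1, 1, 0]])             # [B=1, N=4]
ext = mask.unsqueeze(1).unsqueeze(2)            # [1, 1, 1, 4]
pairwise = ext * ext.squeeze(-2).unsqueeze(-1)  # [1, 1, 4, 4]; entry [..., i, j] = mask[i] * mask[j]
print(pairwise.byte().squeeze())                # row and column of the padded token are all zeros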
418
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
419
+ if self.relative_attention and relative_pos is None:
420
+ q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
421
+ relative_pos = build_relative_position(
422
+ q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
423
+ )
424
+ return relative_pos
425
+
426
+ def forward(
427
+ self,
428
+ hidden_states,
429
+ attention_mask,
430
+ output_hidden_states=True,
431
+ output_attentions=False,
432
+ query_states=None,
433
+ relative_pos=None,
434
+ return_dict=True,
435
+ past_key_values=None,
436
+ ):
437
+ if attention_mask.dim() <= 2:
438
+ input_mask = attention_mask
439
+ else:
440
+ input_mask = (attention_mask.sum(-2) > 0).byte()
441
+ attention_mask = self.get_attention_mask(attention_mask)
442
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
443
+
444
+ all_hidden_states = () if output_hidden_states else None
445
+ all_attentions = () if output_attentions else None
446
+
447
+ if isinstance(hidden_states, Sequence): # False
448
+ next_kv = hidden_states[0]
449
+ else:
450
+ next_kv = hidden_states
451
+ rel_embeddings = self.get_rel_embedding()
452
+ output_states = next_kv
453
+ for i, layer_module in enumerate(self.layer):
454
+
455
+ if output_hidden_states:
456
+ all_hidden_states = all_hidden_states + (output_states,)
457
+
458
+ past_key_value = past_key_values[i] if past_key_values is not None else None
459
+
460
+ output_states = layer_module(
461
+ next_kv,
462
+ attention_mask,
463
+ output_attentions,
464
+ query_states=query_states,
465
+ relative_pos=relative_pos,
466
+ rel_embeddings=rel_embeddings,
467
+ past_key_value=past_key_value,
468
+ )
469
+ if output_attentions:
470
+ output_states, att_m = output_states
471
+
472
+ if i == 0 and self.conv is not None:
473
+ if past_key_values is not None:
474
+ past_key_value_length = past_key_values[0][0].shape[2]
475
+ input_mask = input_mask[:, past_key_value_length:].contiguous()
476
+ output_states = self.conv(hidden_states, output_states, input_mask)
477
+
478
+ if query_states is not None:
479
+ query_states = output_states
480
+ if isinstance(hidden_states, Sequence):
481
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
482
+ else:
483
+ next_kv = output_states
484
+
485
+ if output_attentions:
486
+ all_attentions = all_attentions + (att_m,)
487
+
488
+ if output_hidden_states:
489
+ all_hidden_states = all_hidden_states + (output_states,)
490
+
491
+ if not return_dict:
492
+ return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
493
+ return BaseModelOutput(
494
+ last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
495
+ )
496
+
497
+
498
+ def make_log_bucket_position(relative_pos, bucket_size, max_position):
499
+ sign = np.sign(relative_pos)
500
+ mid = bucket_size // 2
501
+ abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
502
+ log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
503
+ bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(int)  # np.int was removed in NumPy 1.24; the builtin int keeps the same behaviour
504
+ return bucket_pos
505
+
506
+
507
+ def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
508
+ """
509
+ Build relative position according to the query and key
510
+ We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key
511
+ :math:`P_k` is range from (0, key_size), The relative positions from query to key is :math:`R_{q \\rightarrow k} =
512
+ P_q - P_k`
513
+ Args:
514
+ query_size (int): the length of query
515
+ key_size (int): the length of key
516
+ bucket_size (int): the size of position bucket
517
+ max_position (int): the maximum allowed absolute position
518
+ Return:
519
+ :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
520
+ """
521
+ q_ids = np.arange(0, query_size)
522
+ k_ids = np.arange(0, key_size)
523
+ rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
524
+ if bucket_size > 0 and max_position > 0:
525
+ rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
526
+ rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
527
+ rel_pos_ids = rel_pos_ids[:query_size, :]
528
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
529
+ return rel_pos_ids
530
+
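Aside (illustrative sketch, not part of the committed file): build_relative_position for a short sequence, relying on the function defined above; without bucketing (bucket_size = max_position = -1) entry [0, i, j] is simply i - j.

rel = build_relative_position(query_size=4, key_size=4)
print(rel)
# tensor([[[ 0, -1, -2, -3],
#          [ 1,  0, -1, -2],
#          [ 2,  1,  0, -1],
#          [ 3,  2,  1,  0]]])   # shape [1, 4, 4]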
531
+
532
+ @torch.jit.script
533
+ # Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
534
+ def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
535
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
536
+
537
+
538
+ @torch.jit.script
539
+ # Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
540
+ def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
541
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
542
+
543
+
544
+ @torch.jit.script
545
+ # Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
546
+ def pos_dynamic_expand(pos_index, p2c_att, key_layer):
547
+ return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
548
+
549
+
550
+ class DisentangledSelfAttention(nn.Module):
551
+ """
552
+ Disentangled self-attention module
553
+ Parameters:
554
+ config (:obj:`DebertaV2Config`):
555
+ A model config class instance with the configuration to build a new model. The schema is similar to
556
+ `BertConfig`, for more details, please refer :class:`~transformers.DebertaV2Config`
557
+ """
558
+
559
+ def __init__(self, config):
560
+ super().__init__()
561
+ if config.hidden_size % config.num_attention_heads != 0:
562
+ raise ValueError(
563
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
564
+ f"heads ({config.num_attention_heads})"
565
+ )
566
+ self.num_attention_heads = config.num_attention_heads
567
+ _attention_head_size = config.hidden_size // config.num_attention_heads
568
+ self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
569
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
570
+ self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
571
+ self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
572
+ self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
573
+
574
+ self.share_att_key = getattr(config, "share_att_key", False)
575
+ self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
576
+ self.relative_attention = getattr(config, "relative_attention", False)
577
+
578
+ if self.relative_attention:
579
+ self.position_buckets = getattr(config, "position_buckets", -1)
580
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
581
+ if self.max_relative_positions < 1:
582
+ self.max_relative_positions = config.max_position_embeddings
583
+ self.pos_ebd_size = self.max_relative_positions
584
+ if self.position_buckets > 0:
585
+ self.pos_ebd_size = self.position_buckets
586
+
587
+ self.pos_dropout = StableDropout(config.hidden_dropout_prob)
588
+
589
+ if not self.share_att_key:
590
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
591
+ self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
592
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
593
+ self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
594
+
595
+ self.dropout = StableDropout(config.attention_probs_dropout_prob)
596
+
597
+ def transpose_for_scores(self, x, attention_heads, past_key_value=None):
598
+ new_x_shape = x.size()[:-1] + (attention_heads, -1)
599
+ x = x.view(*new_x_shape)
600
+ x = x.permute(0, 2, 1, 3)
601
+ if past_key_value is not None:
602
+ x = torch.cat([past_key_value, x], dim=2)
603
+ new_x_shape = x.shape
604
+ return x.contiguous().view(-1, new_x_shape[2], new_x_shape[-1])
605
+ # return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
606
+
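Aside (shape trace, assuming 12 heads of size 64, i.e. hidden size 768, and a prefix of length P; not part of the committed file): what transpose_for_scores does to a projection output of shape [B, N, 768].

# x (projection output):          [B, N, 768]
# after view:                     [B, N, 12, 64]
# after permute(0, 2, 1, 3):      [B, 12, N, 64]
# after cat with past_key_value:  [B, 12, P + N, 64]
# returned, heads folded into B:  [B * 12, P + N, 64]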
607
+ def forward(
608
+ self,
609
+ hidden_states,
610
+ attention_mask,
611
+ return_att=False,
612
+ query_states=None,
613
+ relative_pos=None,
614
+ rel_embeddings=None,
615
+ past_key_value=None,
616
+ ):
617
+ """
618
+ Call the module
619
+ Args:
620
+ hidden_states (:obj:`torch.FloatTensor`):
621
+ Input states to the module usually the output from previous layer, it will be the Q,K and V in
622
+ `Attention(Q,K,V)`
623
+ attention_mask (:obj:`torch.ByteTensor`):
624
+ An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum
625
+ sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
626
+ th token.
627
+ return_att (:obj:`bool`, optional):
628
+ Whether return the attention matrix.
629
+ query_states (:obj:`torch.FloatTensor`, optional):
630
+ The `Q` state in `Attention(Q,K,V)`.
631
+ relative_pos (:obj:`torch.LongTensor`):
632
+ The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with
633
+ values ranging in [`-max_relative_positions`, `max_relative_positions`].
634
+ rel_embeddings (:obj:`torch.FloatTensor`):
635
+ The embedding of relative distances. It's a tensor of shape [:math:`2 \\times
636
+ \\text{max_relative_positions}`, `hidden_size`].
637
+ """
638
+ if query_states is None:
639
+ query_states = hidden_states
640
+
641
+ past_key_value_length = past_key_value.shape[3] if past_key_value is not None else 0
642
+ if past_key_value is not None:
643
+ key_layer_prefix = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads, past_key_value=past_key_value[0])
644
+ # value_layer_prefix = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads, past_key_value=past_key_value[1])
645
+
646
+ query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
647
+ key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
648
+ value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads, past_key_value=past_key_value[1])
649
+
650
+ rel_att = None
651
+ # Take the dot product between "query" and "key" to get the raw attention scores.
652
+ scale_factor = 1
653
+ if "c2p" in self.pos_att_type:
654
+ scale_factor += 1
655
+ if "p2c" in self.pos_att_type:
656
+ scale_factor += 1
657
+ if "p2p" in self.pos_att_type:
658
+ scale_factor += 1
659
+ scale = math.sqrt(query_layer.size(-1) * scale_factor)
660
+ # attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
661
+ attention_scores = torch.bmm(query_layer, key_layer_prefix.transpose(-1, -2)) / scale
662
+
663
+ if self.relative_attention:
664
+ rel_embeddings = self.pos_dropout(rel_embeddings)
665
+ rel_att = self.disentangled_attention_bias(
666
+ query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
667
+ )
668
+
669
+ if rel_att is not None:
670
+ if past_key_value is not None:
671
+ att_shape = rel_att.shape[:-1] + (past_key_value_length,)
672
+ prefix_att = torch.zeros(*att_shape).to(rel_att.device)
673
+ attention_scores = attention_scores + torch.cat([prefix_att, rel_att], dim=-1)
674
+ else:
675
+ attention_scores = attention_scores + rel_att
676
+ # print(attention_scores.shape)
677
678
+ attention_scores = attention_scores.view(
679
+ -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
680
+ )
681
+
682
+ # bsz x height x length x dimension
683
+ attention_mask = attention_mask[:, :, past_key_value_length:, :]
684
+
685
+ attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
686
+ attention_probs = self.dropout(attention_probs)
687
+
688
+ context_layer = torch.bmm(
689
+ attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
690
+ )
691
+ context_layer = (
692
+ context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
693
+ .permute(0, 2, 1, 3)
694
+ .contiguous()
695
+ )
696
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
697
+ context_layer = context_layer.view(*new_context_layer_shape)
698
+ if return_att:
699
+ return (context_layer, attention_probs)
700
+ else:
701
+ return context_layer
702
+
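Aside (a minimal sketch, not part of the committed file): the raw query-key scores are divided by sqrt(d_head * scale_factor), where scale_factor counts the enabled disentangled terms, so enabling c2p and p2c increases the divisor.

import math

d_head = 64
scale_factor = 1 + 2                       # content-to-content plus "c2p" and "p2c"
scale = math.sqrt(d_head * scale_factor)   # ~13.9, versus sqrt(64) = 8 without relative terms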
703
+ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
704
+ if relative_pos is None:
705
+ q = query_layer.size(-2)
706
+ relative_pos = build_relative_position(
707
+ q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
708
+ )
709
+ if relative_pos.dim() == 2:
710
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
711
+ elif relative_pos.dim() == 3:
712
+ relative_pos = relative_pos.unsqueeze(1)
713
+ # bsz x height x query x key
714
+ elif relative_pos.dim() != 4:
715
+ raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
716
+
717
+ att_span = self.pos_ebd_size
718
+ relative_pos = relative_pos.long().to(query_layer.device)
719
+
720
+ rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze(0)
721
+ if self.share_att_key: # True
722
+ pos_query_layer = self.transpose_for_scores(
723
+ self.query_proj(rel_embeddings), self.num_attention_heads
724
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
725
+ pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
726
+ query_layer.size(0) // self.num_attention_heads, 1, 1
727
+ )
728
+ else:
729
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
730
+ pos_key_layer = self.transpose_for_scores(
731
+ self.pos_key_proj(rel_embeddings), self.num_attention_heads
732
+ ).repeat(
733
+ query_layer.size(0) // self.num_attention_heads, 1, 1
734
+ ) # .split(self.all_head_size, dim=-1)
735
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
736
+ pos_query_layer = self.transpose_for_scores(
737
+ self.pos_query_proj(rel_embeddings), self.num_attention_heads
738
+ ).repeat(
739
+ query_layer.size(0) // self.num_attention_heads, 1, 1
740
+ ) # .split(self.all_head_size, dim=-1)
741
+
742
+ score = 0
743
+ # content->position
744
+ if "c2p" in self.pos_att_type:
745
+ scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
746
+ c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
747
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
748
+ c2p_att = torch.gather(
749
+ c2p_att,
750
+ dim=-1,
751
+ index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
752
+ )
753
+ score += c2p_att / scale
754
+
755
+ # position->content
756
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
757
+ scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
758
+ if key_layer.size(-2) != query_layer.size(-2):
759
+ r_pos = build_relative_position(
760
+ key_layer.size(-2),
761
+ key_layer.size(-2),
762
+ bucket_size=self.position_buckets,
763
+ max_position=self.max_relative_positions,
764
+ ).to(query_layer.device)
765
+ r_pos = r_pos.unsqueeze(0)
766
+ else:
767
+ r_pos = relative_pos
768
+
769
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
770
+ if query_layer.size(-2) != key_layer.size(-2):
771
+ pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
772
+
773
+ if "p2c" in self.pos_att_type:
774
+ p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
775
+ p2c_att = torch.gather(
776
+ p2c_att,
777
+ dim=-1,
778
+ index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
779
+ ).transpose(-1, -2)
780
+ if query_layer.size(-2) != key_layer.size(-2):
781
+ p2c_att = torch.gather(
782
+ p2c_att,
783
+ dim=-2,
784
+ index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))),
785
+ )
786
+ score += p2c_att / scale
787
+
788
+ # position->position
789
+ if "p2p" in self.pos_att_type:
790
+ pos_query = pos_query_layer[:, :, att_span:, :]
791
+ p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
792
+ p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
793
+ if query_layer.size(-2) != key_layer.size(-2):
794
+ p2p_att = torch.gather(
795
+ p2p_att,
796
+ dim=-2,
797
+ index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))),
798
+ )
799
+ p2p_att = torch.gather(
800
+ p2p_att,
801
+ dim=-1,
802
+ index=c2p_pos.expand(
803
+ [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
804
+ ),
805
+ )
806
+ score += p2p_att
807
+
808
+ return score
809
+
810
+
811
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm
812
+ class DebertaV2Embeddings(nn.Module):
813
+ """Construct the embeddings from word, position and token_type embeddings."""
814
+
815
+ def __init__(self, config):
816
+ super().__init__()
817
+ pad_token_id = getattr(config, "pad_token_id", 0)
818
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
819
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
820
+
821
+ self.position_biased_input = getattr(config, "position_biased_input", True)
822
+ if not self.position_biased_input:
823
+ self.position_embeddings = None
824
+ else:
825
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
826
+
827
+ if config.type_vocab_size > 0:
828
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
829
+
830
+ if self.embedding_size != config.hidden_size:
831
+ self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
832
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
833
+ self.dropout = StableDropout(config.hidden_dropout_prob)
834
+ self.config = config
835
+
836
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
837
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
838
+
839
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None, past_key_values_length=0,):
840
+ if input_ids is not None:
841
+ input_shape = input_ids.size()
842
+ else:
843
+ input_shape = inputs_embeds.size()[:-1]
844
+
845
+ seq_length = input_shape[1]
846
+
847
+ if position_ids is None:
848
+ # position_ids = self.position_ids[:, :seq_length]
849
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
850
+
851
+ if token_type_ids is None:
852
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
853
+
854
+ if inputs_embeds is None:
855
+ inputs_embeds = self.word_embeddings(input_ids)
856
+
857
+ if self.position_embeddings is not None:
858
+ position_embeddings = self.position_embeddings(position_ids.long())
859
+ else:
860
+ position_embeddings = torch.zeros_like(inputs_embeds)
861
+
862
+ embeddings = inputs_embeds
863
+ if self.position_biased_input:
864
+ embeddings += position_embeddings
865
+ if self.config.type_vocab_size > 0:
866
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
867
+ embeddings += token_type_embeddings
868
+
869
+ if self.embedding_size != self.config.hidden_size:
870
+ embeddings = self.embed_proj(embeddings)
871
+
872
+ embeddings = self.LayerNorm(embeddings)
873
+
874
+ if mask is not None:
875
+ if mask.dim() != embeddings.dim():
876
+ if mask.dim() == 4:
877
+ mask = mask.squeeze(1).squeeze(1)
878
+ mask = mask.unsqueeze(2)
879
+ mask = mask.to(embeddings.dtype)
880
+
881
+ embeddings = embeddings * mask
882
+
883
+ embeddings = self.dropout(embeddings)
884
+ return embeddings
885
+
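Aside (illustrative, not part of the committed file): with a prompt prefix of length past_key_values_length = P, the position ids for the real tokens start at P rather than 0.

# P = 4, seq_length = 3
# position_ids = self.position_ids[:, 4:7]   ->  tensor([[4, 5, 6]])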
886
+
887
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2
888
+ class DebertaV2PreTrainedModel(PreTrainedModel):
889
+ """
890
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
891
+ models.
892
+ """
893
+
894
+ config_class = DebertaV2Config
895
+ base_model_prefix = "deberta"
896
+ _keys_to_ignore_on_load_missing = ["position_ids"]
897
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
898
+
899
+ def __init__(self, config):
900
+ super().__init__(config)
901
+ self._register_load_state_dict_pre_hook(self._pre_load_hook)
902
+
903
+ def _init_weights(self, module):
904
+ """Initialize the weights."""
905
+ if isinstance(module, nn.Linear):
906
+ # Slightly different from the TF version which uses truncated_normal for initialization
907
+ # cf https://github.com/pytorch/pytorch/pull/5617
908
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
909
+ if module.bias is not None:
910
+ module.bias.data.zero_()
911
+ elif isinstance(module, nn.Embedding):
912
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
913
+ if module.padding_idx is not None:
914
+ module.weight.data[module.padding_idx].zero_()
915
+
916
+ def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
917
+ """
918
+ Removes the classifier if it doesn't have the correct number of labels.
919
+ """
920
+ self_state = self.state_dict()
921
+ if (
922
+ ("classifier.weight" in self_state)
923
+ and ("classifier.weight" in state_dict)
924
+ and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
925
+ ):
926
+ logger.warning(
927
+ f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
928
+ f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
929
+ f"weights. You should train your model on new data."
930
+ )
931
+ del state_dict["classifier.weight"]
932
+ if "classifier.bias" in state_dict:
933
+ del state_dict["classifier.bias"]
934
+
935
+
936
+ DEBERTA_START_DOCSTRING = r"""
937
+ The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
938
+ <https://arxiv.org/abs/2006.03654>`_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built on top of
939
+ BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two
940
+ improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB of pretraining data.
941
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
942
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
943
+ general usage and behavior.
944
+ Parameters:
945
+ config (:class:`~transformers.DebertaV2Config`): Model configuration class with all the parameters of the model.
946
+ Initializing with a config file does not load the weights associated with the model, only the
947
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
948
+ weights.
949
+ """
950
+
951
+ DEBERTA_INPUTS_DOCSTRING = r"""
952
+ Args:
953
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
954
+ Indices of input sequence tokens in the vocabulary.
955
+ Indices can be obtained using :class:`transformers.DebertaV2Tokenizer`. See
956
+ :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
957
+ details.
958
+ `What are input IDs? <../glossary.html#input-ids>`__
959
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`):
960
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
961
+ - 1 for tokens that are **not masked**,
962
+ - 0 for tokens that are **masked**.
963
+ `What are attention masks? <../glossary.html#attention-mask>`__
964
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
965
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
966
+ 1]``:
967
+ - 0 corresponds to a `sentence A` token,
968
+ - 1 corresponds to a `sentence B` token.
969
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
970
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
971
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
972
+ config.max_position_embeddings - 1]``.
973
+ `What are position IDs? <../glossary.html#position-ids>`_
974
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
975
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
976
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
977
+ than the model's internal embedding lookup matrix.
978
+ output_attentions (:obj:`bool`, `optional`):
979
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
980
+ tensors for more detail.
981
+ output_hidden_states (:obj:`bool`, `optional`):
982
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
983
+ more detail.
984
+ return_dict (:obj:`bool`, `optional`):
985
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
986
+ """
987
+
988
+
989
+ @add_start_docstrings(
990
+ "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
991
+ DEBERTA_START_DOCSTRING,
992
+ )
993
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
994
+ class DebertaV2Model(DebertaV2PreTrainedModel):
995
+ def __init__(self, config):
996
+ super().__init__(config)
997
+
998
+ self.embeddings = DebertaV2Embeddings(config)
999
+ self.encoder = DebertaV2Encoder(config)
1000
+ self.z_steps = 0
1001
+ self.config = config
1002
+ self.init_weights()
1003
+
1004
+ def get_input_embeddings(self):
1005
+ return self.embeddings.word_embeddings
1006
+
1007
+ def set_input_embeddings(self, new_embeddings):
1008
+ self.embeddings.word_embeddings = new_embeddings
1009
+
1010
+ def _prune_heads(self, heads_to_prune):
1011
+ """
1012
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1013
+ class PreTrainedModel
1014
+ """
1015
+ raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
1016
+
1017
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1018
+ def forward(
1019
+ self,
1020
+ input_ids=None,
1021
+ attention_mask=None,
1022
+ token_type_ids=None,
1023
+ position_ids=None,
1024
+ inputs_embeds=None,
1025
+ output_attentions=None,
1026
+ output_hidden_states=None,
1027
+ return_dict=None,
1028
+ past_key_values=None,
1029
+ ):
1030
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1031
+ output_hidden_states = (
1032
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1033
+ )
1034
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1035
+
1036
+ if input_ids is not None and inputs_embeds is not None:
1037
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1038
+ elif input_ids is not None:
1039
+ input_shape = input_ids.size()
1040
+ batch_size, seq_length = input_shape
1041
+ elif inputs_embeds is not None:
1042
+ input_shape = inputs_embeds.size()[:-1]
1043
+ batch_size, seq_length = input_shape
1044
+ else:
1045
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1046
+
1047
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1048
+
1049
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1050
+
1051
+ embedding_mask = torch.ones(input_shape, device=device)
1052
+ if attention_mask is None:
1053
+ # attention_mask = torch.ones(input_shape, device=device)
1054
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
1055
+ if token_type_ids is None:
1056
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1057
+
1058
+ embedding_output = self.embeddings(
1059
+ input_ids=input_ids,
1060
+ token_type_ids=token_type_ids,
1061
+ position_ids=position_ids,
1062
+ # mask=attention_mask,
1063
+ mask=embedding_mask,
1064
+ inputs_embeds=inputs_embeds,
1065
+ past_key_values_length=past_key_values_length, # Ongoing
1066
+ )
1067
+
1068
+ encoder_outputs = self.encoder(
1069
+ embedding_output,
1070
+ attention_mask,
1071
+ output_hidden_states=True,
1072
+ output_attentions=output_attentions,
1073
+ return_dict=return_dict,
1074
+ past_key_values=past_key_values, # Ongoing
1075
+ )
1076
+ encoded_layers = encoder_outputs[1]
1077
+
1078
+ if self.z_steps > 1:
1079
+ hidden_states = encoded_layers[-2]
1080
+ layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
1081
+ query_states = encoded_layers[-1]
1082
+ rel_embeddings = self.encoder.get_rel_embedding()
1083
+ attention_mask = self.encoder.get_attention_mask(attention_mask)
1084
+ rel_pos = self.encoder.get_rel_pos(embedding_output)
1085
+ for layer in layers[1:]:
1086
+ query_states = layer(
1087
+ hidden_states,
1088
+ attention_mask,
1089
+ return_att=False,
1090
+ query_states=query_states,
1091
+ relative_pos=rel_pos,
1092
+ rel_embeddings=rel_embeddings,
1093
+ )
1094
+ encoded_layers.append(query_states)
1095
+
1096
+ sequence_output = encoded_layers[-1]
1097
+
1098
+ if not return_dict:
1099
+ return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
1100
+
1101
+ return BaseModelOutput(
1102
+ last_hidden_state=sequence_output,
1103
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
1104
+ attentions=encoder_outputs.attentions,
1105
+ )
1106
+
1107
+
1108
+ @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING)
1109
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
1110
+ class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
1111
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1112
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1113
+
1114
+ def __init__(self, config):
1115
+ super().__init__(config)
1116
+
1117
+ self.deberta = DebertaV2Model(config)
1118
+ self.cls = DebertaV2OnlyMLMHead(config)
1119
+
1120
+ self.init_weights()
1121
+
1122
+ def get_output_embeddings(self):
1123
+ return self.cls.predictions.decoder
1124
+
1125
+ def set_output_embeddings(self, new_embeddings):
1126
+ self.cls.predictions.decoder = new_embeddings
1127
+
1128
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1129
+ def forward(
1130
+ self,
1131
+ input_ids=None,
1132
+ attention_mask=None,
1133
+ token_type_ids=None,
1134
+ position_ids=None,
1135
+ inputs_embeds=None,
1136
+ labels=None,
1137
+ output_attentions=None,
1138
+ output_hidden_states=None,
1139
+ return_dict=None,
1140
+ ):
1141
+ r"""
1142
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1143
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1144
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1145
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1146
+ """
1147
+
1148
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1149
+
1150
+ outputs = self.deberta(
1151
+ input_ids,
1152
+ attention_mask=attention_mask,
1153
+ token_type_ids=token_type_ids,
1154
+ position_ids=position_ids,
1155
+ inputs_embeds=inputs_embeds,
1156
+ output_attentions=output_attentions,
1157
+ output_hidden_states=output_hidden_states,
1158
+ return_dict=return_dict,
1159
+ )
1160
+
1161
+ sequence_output = outputs[0]
1162
+ prediction_scores = self.cls(sequence_output)
1163
+
1164
+ masked_lm_loss = None
1165
+ if labels is not None:
1166
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1167
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1168
+
1169
+ if not return_dict:
1170
+ output = (prediction_scores,) + outputs[1:]
1171
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1172
+
1173
+ return MaskedLMOutput(
1174
+ loss=masked_lm_loss,
1175
+ logits=prediction_scores,
1176
+ hidden_states=outputs.hidden_states,
1177
+ attentions=outputs.attentions,
1178
+ )
1179
+
1180
+
1181
+ # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
1182
+ class DebertaV2PredictionHeadTransform(nn.Module):
1183
+ def __init__(self, config):
1184
+ super().__init__()
1185
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1186
+ if isinstance(config.hidden_act, str):
1187
+ self.transform_act_fn = ACT2FN[config.hidden_act]
1188
+ else:
1189
+ self.transform_act_fn = config.hidden_act
1190
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1191
+
1192
+ def forward(self, hidden_states):
1193
+ hidden_states = self.dense(hidden_states)
1194
+ hidden_states = self.transform_act_fn(hidden_states)
1195
+ hidden_states = self.LayerNorm(hidden_states)
1196
+ return hidden_states
1197
+
1198
+
1199
+ # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
1200
+ class DebertaV2LMPredictionHead(nn.Module):
1201
+ def __init__(self, config):
1202
+ super().__init__()
1203
+ self.transform = DebertaV2PredictionHeadTransform(config)
1204
+
1205
+ # The output weights are the same as the input embeddings, but there is
1206
+ # an output-only bias for each token.
1207
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1208
+
1209
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1210
+
1211
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
1212
+ self.decoder.bias = self.bias
1213
+
1214
+ def forward(self, hidden_states):
1215
+ hidden_states = self.transform(hidden_states)
1216
+ hidden_states = self.decoder(hidden_states)
1217
+ return hidden_states
1218
+
1219
+
1220
+ # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
1221
+ class DebertaV2OnlyMLMHead(nn.Module):
1222
+ def __init__(self, config):
1223
+ super().__init__()
1224
+ self.predictions = DebertaV2LMPredictionHead(config)
1225
+
1226
+ def forward(self, sequence_output):
1227
+ prediction_scores = self.predictions(sequence_output)
1228
+ return prediction_scores
1229
+
1230
+
1231
+ @add_start_docstrings(
1232
+ """
1233
+ DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1234
+ pooled output) e.g. for GLUE tasks.
1235
+ """,
1236
+ DEBERTA_START_DOCSTRING,
1237
+ )
1238
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2
1239
+ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
1240
+ def __init__(self, config):
1241
+ super().__init__(config)
1242
+
1243
+ num_labels = getattr(config, "num_labels", 2)
1244
+ self.num_labels = num_labels
1245
+
1246
+ self.deberta = DebertaV2Model(config)
1247
+ self.pooler = ContextPooler(config)
1248
+ output_dim = self.pooler.output_dim
1249
+
1250
+ self.classifier = nn.Linear(output_dim, num_labels)
1251
+ drop_out = getattr(config, "cls_dropout", None)
1252
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
1253
+ self.dropout = StableDropout(drop_out)
1254
+
1255
+ self.init_weights()
1256
+
1257
+ def get_input_embeddings(self):
1258
+ return self.deberta.get_input_embeddings()
1259
+
1260
+ def set_input_embeddings(self, new_embeddings):
1261
+ self.deberta.set_input_embeddings(new_embeddings)
1262
+
1263
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1264
+ def forward(
1265
+ self,
1266
+ input_ids=None,
1267
+ attention_mask=None,
1268
+ token_type_ids=None,
1269
+ position_ids=None,
1270
+ inputs_embeds=None,
1271
+ labels=None,
1272
+ output_attentions=None,
1273
+ output_hidden_states=None,
1274
+ return_dict=None,
1275
+ ):
1276
+ r"""
1277
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1278
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1279
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1280
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1281
+ """
1282
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1283
+
1284
+ outputs = self.deberta(
1285
+ input_ids,
1286
+ token_type_ids=token_type_ids,
1287
+ attention_mask=attention_mask,
1288
+ position_ids=position_ids,
1289
+ inputs_embeds=inputs_embeds,
1290
+ output_attentions=output_attentions,
1291
+ output_hidden_states=output_hidden_states,
1292
+ return_dict=return_dict,
1293
+ )
1294
+
1295
+ encoder_layer = outputs[0]
1296
+ pooled_output = self.pooler(encoder_layer)
1297
+ pooled_output = self.dropout(pooled_output)
1298
+ logits = self.classifier(pooled_output)
1299
+
1300
+ loss = None
1301
+ if labels is not None:
1302
+ if self.num_labels == 1:
1303
+ # regression task
1304
+ loss_fn = nn.MSELoss()
1305
+ logits = logits.view(-1).to(labels.dtype)
1306
+ loss = loss_fn(logits, labels.view(-1))
1307
+ elif labels.dim() == 1 or labels.size(-1) == 1:
1308
+ label_index = (labels >= 0).nonzero()
1309
+ labels = labels.long()
1310
+ if label_index.size(0) > 0:
1311
+ labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
1312
+ labels = torch.gather(labels, 0, label_index.view(-1))
1313
+ loss_fct = CrossEntropyLoss()
1314
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
1315
+ else:
1316
+ loss = torch.tensor(0).to(logits)
1317
+ else:
1318
+ log_softmax = nn.LogSoftmax(-1)
1319
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
1320
+ if not return_dict:
1321
+ output = (logits,) + outputs[1:]
1322
+ return ((loss,) + output) if loss is not None else output
1323
+ else:
1324
+ return SequenceClassifierOutput(
1325
+ loss=loss,
1326
+ logits=logits,
1327
+ hidden_states=outputs.hidden_states,
1328
+ attentions=outputs.attentions,
1329
+ )
1330
+
1331
+
1332
+ @add_start_docstrings(
1333
+ """
1334
+ DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1335
+ Named-Entity-Recognition (NER) tasks.
1336
+ """,
1337
+ DEBERTA_START_DOCSTRING,
1338
+ )
1339
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
1340
+ class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
1341
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1342
+
1343
+ def __init__(self, config):
1344
+ super().__init__(config)
1345
+ self.num_labels = config.num_labels
1346
+
1347
+ self.deberta = DebertaV2Model(config)
1348
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1349
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1350
+
1351
+ self.init_weights()
1352
+ for param in self.deberta.parameters():
1353
+ param.requires_grad = False
1354
+
1355
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1356
+ def forward(
1357
+ self,
1358
+ input_ids=None,
1359
+ attention_mask=None,
1360
+ token_type_ids=None,
1361
+ position_ids=None,
1362
+ inputs_embeds=None,
1363
+ labels=None,
1364
+ output_attentions=None,
1365
+ output_hidden_states=None,
1366
+ return_dict=None,
1367
+ past_key_values=None,
1368
+ ):
1369
+ r"""
1370
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1371
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1372
+ 1]``.
1373
+ """
1374
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1375
+
1376
+ outputs = self.deberta(
1377
+ input_ids,
1378
+ attention_mask=attention_mask,
1379
+ token_type_ids=token_type_ids,
1380
+ position_ids=position_ids,
1381
+ inputs_embeds=inputs_embeds,
1382
+ output_attentions=output_attentions,
1383
+ output_hidden_states=output_hidden_states,
1384
+ return_dict=return_dict,
1385
+ )
1386
+
1387
+ sequence_output = outputs[0]
1388
+
1389
+ sequence_output = self.dropout(sequence_output)
1390
+ logits = self.classifier(sequence_output)
1391
+
1392
+ loss = None
1393
+ if labels is not None:
1394
+ loss_fct = CrossEntropyLoss()
1395
+ # Only keep active parts of the loss
1396
+ if attention_mask is not None:
1397
+ active_loss = attention_mask.view(-1) == 1
1398
+ active_logits = logits.view(-1, self.num_labels)
1399
+ active_labels = torch.where(
1400
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
1401
+ )
1402
+ loss = loss_fct(active_logits, active_labels)
1403
+ else:
1404
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1405
+
1406
+ if not return_dict:
1407
+ output = (logits,) + outputs[1:]
1408
+ return ((loss,) + output) if loss is not None else output
1409
+
1410
+ return TokenClassifierOutput(
1411
+ loss=loss,
1412
+ logits=logits,
1413
+ hidden_states=outputs.hidden_states,
1414
+ attentions=outputs.attentions,
1415
+ )
1416
+
1417
+
1418
+ @add_start_docstrings(
1419
+ """
1420
+ DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (linear
1421
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1422
+ """,
1423
+ DEBERTA_START_DOCSTRING,
1424
+ )
1425
+ # Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2
1426
+ class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
1427
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1428
+
1429
+ def __init__(self, config):
1430
+ super().__init__(config)
1431
+ self.num_labels = config.num_labels
1432
+
1433
+ self.deberta = DebertaV2Model(config)
1434
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1435
+
1436
+ self.init_weights()
1437
+
1438
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1439
+ def forward(
1440
+ self,
1441
+ input_ids=None,
1442
+ attention_mask=None,
1443
+ token_type_ids=None,
1444
+ position_ids=None,
1445
+ inputs_embeds=None,
1446
+ start_positions=None,
1447
+ end_positions=None,
1448
+ output_attentions=None,
1449
+ output_hidden_states=None,
1450
+ return_dict=None,
1451
+ ):
1452
+ r"""
1453
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1454
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1455
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1456
+ sequence are not taken into account for computing the loss.
1457
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1458
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1459
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1460
+ sequence are not taken into account for computing the loss.
1461
+ """
1462
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1463
+
1464
+ outputs = self.deberta(
1465
+ input_ids,
1466
+ attention_mask=attention_mask,
1467
+ token_type_ids=token_type_ids,
1468
+ position_ids=position_ids,
1469
+ inputs_embeds=inputs_embeds,
1470
+ output_attentions=output_attentions,
1471
+ output_hidden_states=output_hidden_states,
1472
+ return_dict=return_dict,
1473
+ )
1474
+
1475
+ sequence_output = outputs[0]
1476
+
1477
+ logits = self.qa_outputs(sequence_output)
1478
+ start_logits, end_logits = logits.split(1, dim=-1)
1479
+ start_logits = start_logits.squeeze(-1).contiguous()
1480
+ end_logits = end_logits.squeeze(-1).contiguous()
1481
+
1482
+ total_loss = None
1483
+ if start_positions is not None and end_positions is not None:
1484
+ # If we are on multi-GPU, the batch split can add an extra dimension; squeeze it away
1485
+ if len(start_positions.size()) > 1:
1486
+ start_positions = start_positions.squeeze(-1)
1487
+ if len(end_positions.size()) > 1:
1488
+ end_positions = end_positions.squeeze(-1)
1489
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1490
+ ignored_index = start_logits.size(1)
1491
+ start_positions = start_positions.clamp(0, ignored_index)
1492
+ end_positions = end_positions.clamp(0, ignored_index)
1493
+
1494
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1495
+ start_loss = loss_fct(start_logits, start_positions)
1496
+ end_loss = loss_fct(end_logits, end_positions)
1497
+ total_loss = (start_loss + end_loss) / 2
1498
+
1499
+ if not return_dict:
1500
+ output = (start_logits, end_logits) + outputs[1:]
1501
+ return ((total_loss,) + output) if total_loss is not None else output
1502
+
1503
+ return QuestionAnsweringModelOutput(
1504
+ loss=total_loss,
1505
+ start_logits=start_logits,
1506
+ end_logits=end_logits,
1507
+ hidden_states=outputs.hidden_states,
1508
+ attentions=outputs.attentions,
1509
+ )
soft_prompt/model/multiple_choice.py ADDED
@@ -0,0 +1,710 @@
1
+ import torch
2
+ from torch._C import NoopLogger
3
+ import torch.nn
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
7
+
8
+ from transformers import BertModel, BertPreTrainedModel
9
+ from transformers import RobertaModel, RobertaPreTrainedModel
10
+ from transformers.modeling_outputs import MultipleChoiceModelOutput, BaseModelOutput, Seq2SeqLMOutput
11
+
12
+ from model.prefix_encoder import PrefixEncoder
13
+ from model.deberta import DebertaModel, DebertaPreTrainedModel, ContextPooler, StableDropout
14
+ from model import utils
15
+
16
+
17
+ class BertForMultipleChoice(BertPreTrainedModel):
18
+ """BERT model for multiple choice tasks.
19
+ This module is composed of the BERT model with a linear layer on top of
20
+ the pooled output.
21
+
22
+ Params:
23
+ `config`: a BertConfig class instance with the configuration to build a new model.
24
+ `num_choices`: the number of classes for the classifier. Default = 2.
25
+
26
+ Inputs:
27
+ `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
28
+ with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
29
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
30
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
31
+ with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
32
+ and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
33
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
34
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
35
+ input sequence length in the current batch. It's the mask that we typically use for attention when
36
+ a batch has varying length sentences.
37
+ `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
38
+ with indices selected in [0, ..., num_choices].
39
+
40
+ Outputs:
41
+ if `labels` is not `None`:
42
+ Outputs the CrossEntropy classification loss of the output with the labels.
43
+ if `labels` is `None`:
44
+ Outputs the classification logits of shape [batch_size, num_labels].
45
+
46
+ Example usage:
47
+ ```python
48
+ # Already been converted into WordPiece token ids
49
+ input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
50
+ input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]])
51
+ token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]])
52
+ config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
53
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
54
+
55
+ num_choices = 2
56
+
57
+ model = BertForMultipleChoice(config, num_choices)
58
+ logits = model(input_ids, token_type_ids, input_mask)
59
+ ```
60
+ """
61
+
62
+ def __init__(self, config):
63
+ super().__init__(config)
64
+ self.bert = BertModel(config)
65
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
66
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
67
+
68
+ self.init_weights()
69
+
70
+ self.embedding = utils.get_embeddings(self, config)
71
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
72
+
73
+ def forward(
74
+ self,
75
+ input_ids=None,
76
+ attention_mask=None,
77
+ token_type_ids=None,
78
+ position_ids=None,
79
+ head_mask=None,
80
+ inputs_embeds=None,
81
+ labels=None,
82
+ output_attentions=None,
83
+ output_hidden_states=None,
84
+ return_dict=None,
85
+ use_base_grad=False
86
+ ):
87
+ utils.use_grad(self.bert, use_base_grad)
88
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
89
+ batch_size, num_choices = input_ids.shape[:2]
90
+
91
+ input_ids = input_ids.reshape(-1, input_ids.size(-1))
92
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1))
93
+ attention_mask = attention_mask.reshape(-1, attention_mask.size(-1))
94
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
95
+ inputs_embeds = (
96
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
97
+ if inputs_embeds is not None
98
+ else None
99
+ )
100
+
101
+ outputs = self.bert(
102
+ input_ids,
103
+ attention_mask=attention_mask,
104
+ token_type_ids=token_type_ids,
105
+ position_ids=position_ids,
106
+ head_mask=head_mask,
107
+ inputs_embeds=inputs_embeds,
108
+ output_attentions=output_attentions,
109
+ output_hidden_states=output_hidden_states,
110
+ return_dict=return_dict,
111
+ )
112
+
113
+ pooled_output = outputs[1]
114
+
115
+ pooled_output = self.dropout(pooled_output)
116
+ logits = self.classifier(pooled_output)
117
+ reshaped_logits = logits.reshape(-1, num_choices)
118
+
119
+ loss = None
120
+ if labels is not None:
121
+ loss_fct = CrossEntropyLoss()
122
+ loss = loss_fct(reshaped_logits, labels)
123
+
124
+ if not return_dict:
125
+ output = (reshaped_logits,) + outputs[2:]
126
+ return ((loss,) + output) if loss is not None else output
127
+
128
+ return MultipleChoiceModelOutput(
129
+ loss=loss,
130
+ logits=reshaped_logits,
131
+ hidden_states=outputs.hidden_states,
132
+ attentions=outputs.attentions,
133
+ )
134
+
135
+ class BertPrefixForMultipleChoice(BertPreTrainedModel):
136
+ def __init__(self, config):
137
+ super().__init__(config)
138
+ self.num_labels = config.num_labels
139
+ self.config = config
140
+ self.bert = BertModel(config)
141
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
142
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
143
+
144
+ for param in self.bert.parameters():
145
+ param.requires_grad = False
146
+
147
+ self.pre_seq_len = config.pre_seq_len
148
+ self.n_layer = config.num_hidden_layers
149
+ self.n_head = config.num_attention_heads
150
+ self.n_embd = config.hidden_size // config.num_attention_heads
151
+
152
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
153
+ self.prefix_encoder = PrefixEncoder(config)
154
+
155
+ bert_param = 0
156
+ for name, param in self.bert.named_parameters():
157
+ bert_param += param.numel()
158
+ all_param = 0
159
+ for name, param in self.named_parameters():
160
+ all_param += param.numel()
161
+ total_param = all_param - bert_param
162
+ print('total param is {}'.format(total_param)) # 9860105
163
+
164
+ self.embedding = utils.get_embeddings(self, config)
165
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
166
+
167
+ def get_prompt(self, batch_size):
168
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
169
+ past_key_values = self.prefix_encoder(prefix_tokens)
170
+ past_key_values = past_key_values.view(
171
+ batch_size,
172
+ self.pre_seq_len,
173
+ self.n_layer * 2,
174
+ self.n_head,
175
+ self.n_embd
176
+ )
177
+ past_key_values = self.dropout(past_key_values)
178
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
179
+ return past_key_values
180
+
181
+ def forward(
182
+ self,
183
+ input_ids=None,
184
+ attention_mask=None,
185
+ token_type_ids=None,
186
+ position_ids=None,
187
+ head_mask=None,
188
+ inputs_embeds=None,
189
+ labels=None,
190
+ output_attentions=None,
191
+ output_hidden_states=None,
192
+ return_dict=None,
193
+ use_base_grad=False
194
+ ):
195
+ utils.use_grad(self.bert, use_base_grad)
196
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
197
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
198
+
199
+ input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
200
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
201
+ attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
202
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
203
+ inputs_embeds = (
204
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
205
+ if inputs_embeds is not None
206
+ else None
207
+ )
208
+
209
+ past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
210
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device)
211
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
212
+
213
+ outputs = self.bert(
214
+ input_ids,
215
+ attention_mask=attention_mask,
216
+ token_type_ids=token_type_ids,
217
+ position_ids=position_ids,
218
+ head_mask=head_mask,
219
+ inputs_embeds=inputs_embeds,
220
+ output_attentions=output_attentions,
221
+ output_hidden_states=output_hidden_states,
222
+ return_dict=return_dict,
223
+ past_key_values=past_key_values,
224
+ )
225
+
226
+ pooled_output = outputs[1]
227
+
228
+ pooled_output = self.dropout(pooled_output)
229
+ logits = self.classifier(pooled_output)
230
+ reshaped_logits = logits.reshape(-1, num_choices)
231
+
232
+ loss = None
233
+ if labels is not None:
234
+ loss_fct = CrossEntropyLoss()
235
+ loss = loss_fct(reshaped_logits, labels)
236
+
237
+ if not return_dict:
238
+ output = (reshaped_logits,) + outputs[2:]
239
+ return ((loss,) + output) if loss is not None else output
240
+
241
+ return MultipleChoiceModelOutput(
242
+ loss=loss,
243
+ logits=reshaped_logits,
244
+ hidden_states=outputs.hidden_states,
245
+ attentions=outputs.attentions,
246
+ )
247
+
248
+ class RobertaPrefixForMultipleChoice(RobertaPreTrainedModel):
249
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
250
+
251
+ def __init__(self, config):
252
+ super().__init__(config)
253
+
254
+ self.roberta = RobertaModel(config)
255
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
256
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
257
+
258
+ self.init_weights()
259
+
260
+
261
+ for param in self.roberta.parameters():
262
+ param.requires_grad = False
263
+
264
+ self.pre_seq_len = config.pre_seq_len
265
+ self.n_layer = config.num_hidden_layers
266
+ self.n_head = config.num_attention_heads
267
+ self.n_embd = config.hidden_size // config.num_attention_heads
268
+
269
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
270
+ self.prefix_encoder = PrefixEncoder(config)
271
+
272
+ bert_param = 0
273
+ for name, param in self.roberta.named_parameters():
274
+ bert_param += param.numel()
275
+ all_param = 0
276
+ for name, param in self.named_parameters():
277
+ all_param += param.numel()
278
+ total_param = all_param - bert_param
279
+ print('total param is {}'.format(total_param))
280
+
281
+ self.embedding = utils.get_embeddings(self, config)
282
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
283
+
284
+ def get_prompt(self, batch_size):
285
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
286
+ past_key_values = self.prefix_encoder(prefix_tokens)
287
+ past_key_values = past_key_values.view(
288
+ batch_size,
289
+ self.pre_seq_len,
290
+ self.n_layer * 2,
291
+ self.n_head,
292
+ self.n_embd
293
+ )
294
+ past_key_values = self.dropout(past_key_values)
295
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
296
+ return past_key_values
297
+
298
+ def forward(
299
+ self,
300
+ input_ids=None,
301
+ token_type_ids=None,
302
+ attention_mask=None,
303
+ labels=None,
304
+ position_ids=None,
305
+ head_mask=None,
306
+ inputs_embeds=None,
307
+ output_attentions=None,
308
+ output_hidden_states=None,
309
+ return_dict=None,
310
+ use_base_grad=False
311
+ ):
312
+ r"""
313
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
314
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
315
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
316
+ :obj:`input_ids` above)
317
+ """
318
+ utils.use_grad(self.roberta, use_base_grad)
319
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
320
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
321
+
322
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
323
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
324
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
325
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
326
+ flat_inputs_embeds = (
327
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
328
+ if inputs_embeds is not None
329
+ else None
330
+ )
331
+
332
+ past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
333
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device)
334
+ flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1)
335
+
336
+ outputs = self.roberta(
337
+ flat_input_ids,
338
+ position_ids=flat_position_ids,
339
+ token_type_ids=flat_token_type_ids,
340
+ attention_mask=flat_attention_mask,
341
+ head_mask=head_mask,
342
+ inputs_embeds=flat_inputs_embeds,
343
+ output_attentions=output_attentions,
344
+ output_hidden_states=output_hidden_states,
345
+ return_dict=return_dict,
346
+ past_key_values=past_key_values,
347
+ )
348
+ pooled_output = outputs[1]
349
+
350
+ pooled_output = self.dropout(pooled_output)
351
+ logits = self.classifier(pooled_output)
352
+ reshaped_logits = logits.view(-1, num_choices)
353
+
354
+ loss = None
355
+ if labels is not None:
356
+ loss_fct = CrossEntropyLoss()
357
+ loss = loss_fct(reshaped_logits, labels)
358
+
359
+ if not return_dict:
360
+ output = (reshaped_logits,) + outputs[2:]
361
+ return ((loss,) + output) if loss is not None else output
362
+
363
+ return MultipleChoiceModelOutput(
364
+ loss=loss,
365
+ logits=reshaped_logits,
366
+ hidden_states=outputs.hidden_states,
367
+ attentions=outputs.attentions,
368
+ )
369
+
370
+ class DebertaPrefixForMultipleChoice(DebertaPreTrainedModel):
371
+ def __init__(self, config):
372
+ super().__init__(config)
373
+ self.num_labels = config.num_labels
374
+ self.config = config
375
+ self.deberta = DebertaModel(config)
376
+ self.pooler = ContextPooler(config)
377
+ output_dim = self.pooler.output_dim
378
+ self.classifier = torch.nn.Linear(output_dim, 1)
379
+ self.dropout = StableDropout(config.hidden_dropout_prob)
380
+ self.init_weights()
381
+
382
+ for param in self.deberta.parameters():
383
+ param.requires_grad = False
384
+
385
+ self.pre_seq_len = config.pre_seq_len
386
+ self.n_layer = config.num_hidden_layers
387
+ self.n_head = config.num_attention_heads
388
+ self.n_embd = config.hidden_size // config.num_attention_heads
389
+
390
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
391
+ self.prefix_encoder = PrefixEncoder(config)
392
+
393
+ deberta_param = 0
394
+ for name, param in self.deberta.named_parameters():
395
+ deberta_param += param.numel()
396
+ all_param = 0
397
+ for name, param in self.named_parameters():
398
+ all_param += param.numel()
399
+ total_param = all_param - deberta_param
400
+ print('total param is {}'.format(total_param))
401
+
402
+ self.embedding = utils.get_embeddings(self, config)
403
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
404
+
405
+ def get_prompt(self, batch_size):
406
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
407
+ past_key_values = self.prefix_encoder(prefix_tokens)
408
+ past_key_values = past_key_values.view(
409
+ batch_size,
410
+ self.pre_seq_len,
411
+ self.n_layer * 2,
412
+ self.n_head,
413
+ self.n_embd
414
+ )
415
+ past_key_values = self.dropout(past_key_values)
416
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
417
+ return past_key_values
418
+
419
+ def forward(
420
+ self,
421
+ input_ids=None,
422
+ attention_mask=None,
423
+ token_type_ids=None,
424
+ position_ids=None,
425
+ head_mask=None,
426
+ inputs_embeds=None,
427
+ labels=None,
428
+ output_attentions=None,
429
+ output_hidden_states=None,
430
+ return_dict=None,
431
+ use_base_grad=False
432
+ ):
433
+ utils.use_grad(self.deberta, use_base_grad)
434
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
435
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
436
+
437
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
438
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
439
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
440
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
441
+ flat_inputs_embeds = (
442
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
443
+ if inputs_embeds is not None
444
+ else None
445
+ )
446
+
447
+ past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
448
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.deberta.device)
449
+ flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1)
450
+
451
+ outputs = self.deberta(
452
+ flat_input_ids,
453
+ attention_mask=flat_attention_mask,
454
+ token_type_ids=flat_token_type_ids,
455
+ position_ids=flat_position_ids,
456
+ inputs_embeds=flat_inputs_embeds,
457
+ output_attentions=output_attentions,
458
+ output_hidden_states=output_hidden_states,
459
+ return_dict=return_dict,
460
+ past_key_values=past_key_values,
461
+ )
462
+
463
+ encoder_layer = outputs[0]
464
+
465
+ pooled_output = self.pooler(encoder_layer)
466
+ pooled_output = self.dropout(pooled_output)
467
+ logits = self.classifier(pooled_output)
468
+ reshaped_logits = logits.view(-1, num_choices)
469
+
470
+ loss = None
471
+ if labels is not None:
472
+ loss_fct = CrossEntropyLoss()
473
+ loss = loss_fct(reshaped_logits, labels)
474
+
475
+ if not return_dict:
476
+ output = (reshaped_logits,) + outputs[2:]
477
+ return ((loss,) + output) if loss is not None else output
478
+
479
+ return MultipleChoiceModelOutput(
480
+ loss=loss,
481
+ logits=reshaped_logits,
482
+ hidden_states=outputs.hidden_states,
483
+ attentions=outputs.attentions,
484
+ )
485
+
486
+
487
+ class BertPromptForMultipleChoice(BertPreTrainedModel):
488
+ def __init__(self, config):
489
+ super().__init__(config)
490
+ self.num_labels = config.num_labels
491
+ self.config = config
492
+ self.bert = BertModel(config)
493
+ self.embeddings = self.bert.embeddings
494
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
495
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
496
+
497
+ for param in self.bert.parameters():
498
+ param.requires_grad = False
499
+
500
+ self.pre_seq_len = config.pre_seq_len
501
+ self.n_layer = config.num_hidden_layers
502
+ self.n_head = config.num_attention_heads
503
+ self.n_embd = config.hidden_size // config.num_attention_heads
504
+
505
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
506
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
507
+
508
+ bert_param = 0
509
+ for name, param in self.bert.named_parameters():
510
+ bert_param += param.numel()
511
+ all_param = 0
512
+ for name, param in self.named_parameters():
513
+ all_param += param.numel()
514
+ total_param = all_param - bert_param
515
+ print('total param is {}'.format(total_param)) # 9860105
516
+
517
+ self.embedding = utils.get_embeddings(self, config)
518
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
519
+
520
+ def get_prompt(self, batch_size):
521
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
522
+ prompts = self.prefix_encoder(prefix_tokens)
523
+ return prompts
524
+
525
+ def forward(
526
+ self,
527
+ input_ids=None,
528
+ attention_mask=None,
529
+ token_type_ids=None,
530
+ position_ids=None,
531
+ head_mask=None,
532
+ inputs_embeds=None,
533
+ labels=None,
534
+ output_attentions=None,
535
+ output_hidden_states=None,
536
+ return_dict=None,
537
+ use_base_grad=False
538
+ ):
539
+ utils.use_grad(self.bert, use_base_grad)
540
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
541
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
542
+
543
+ input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
544
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
545
+ attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
546
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
547
+ inputs_embeds = (
548
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
549
+ if inputs_embeds is not None
550
+ else None
551
+ )
552
+
553
+ raw_embedding = self.embeddings(
554
+ input_ids=input_ids,
555
+ position_ids=position_ids,
556
+ token_type_ids=token_type_ids,
557
+ )
558
+ prompts = self.get_prompt(batch_size=batch_size * num_choices)
559
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
560
+
561
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device)
562
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
563
+
564
+ outputs = self.bert(
565
+ attention_mask=attention_mask,
566
+ head_mask=head_mask,
567
+ inputs_embeds=inputs_embeds,
568
+ output_attentions=output_attentions,
569
+ output_hidden_states=output_hidden_states,
570
+ return_dict=return_dict,
571
+ )
572
+
573
+ pooled_output = outputs[1]
574
+
575
+ pooled_output = self.dropout(pooled_output)
576
+ logits = self.classifier(pooled_output)
577
+ reshaped_logits = logits.reshape(-1, num_choices)
578
+
579
+ loss = None
580
+ if labels is not None:
581
+ loss_fct = CrossEntropyLoss()
582
+ loss = loss_fct(reshaped_logits, labels)
583
+
584
+ if not return_dict:
585
+ output = (reshaped_logits,) + outputs[2:]
586
+ return ((loss,) + output) if loss is not None else output
587
+
588
+ return MultipleChoiceModelOutput(
589
+ loss=loss,
590
+ logits=reshaped_logits,
591
+ hidden_states=outputs.hidden_states,
592
+ attentions=outputs.attentions,
593
+ )
594
+
595
+
596
+ class RobertaPromptForMultipleChoice(RobertaPreTrainedModel):
597
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
598
+
599
+ def __init__(self, config):
600
+ super().__init__(config)
601
+
602
+ self.roberta = RobertaModel(config)
603
+ self.embeddings = self.roberta.embeddings
604
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
605
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
606
+
607
+ self.init_weights()
608
+
609
+
610
+ for param in self.roberta.parameters():
611
+ param.requires_grad = False
612
+
613
+ self.pre_seq_len = config.pre_seq_len
614
+ self.n_layer = config.num_hidden_layers
615
+ self.n_head = config.num_attention_heads
616
+ self.n_embd = config.hidden_size // config.num_attention_heads
617
+
618
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
619
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
620
+
621
+ bert_param = 0
622
+ for name, param in self.roberta.named_parameters():
623
+ bert_param += param.numel()
624
+ all_param = 0
625
+ for name, param in self.named_parameters():
626
+ all_param += param.numel()
627
+ total_param = all_param - bert_param
628
+ print('total param is {}'.format(total_param))
629
+
630
+ self.embedding = utils.get_embeddings(self, config)
631
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
632
+
633
+ def get_prompt(self, batch_size):
634
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
635
+ prompts = self.prefix_encoder(prefix_tokens)
636
+ return prompts
637
+
638
+ def forward(
639
+ self,
640
+ input_ids=None,
641
+ token_type_ids=None,
642
+ attention_mask=None,
643
+ labels=None,
644
+ position_ids=None,
645
+ head_mask=None,
646
+ inputs_embeds=None,
647
+ output_attentions=None,
648
+ output_hidden_states=None,
649
+ return_dict=None,
650
+ use_base_grad=False
651
+ ):
652
+ r"""
653
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
654
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
655
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
656
+ :obj:`input_ids` above)
657
+ """
658
+ utils.use_grad(self.roberta, use_base_grad)
659
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
660
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
661
+
662
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
663
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
664
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
665
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
666
+ inputs_embeds = (
667
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
668
+ if inputs_embeds is not None
669
+ else None
670
+ )
671
+
672
+ raw_embedding = self.embeddings(
673
+ input_ids=input_ids,
674
+ position_ids=position_ids,
675
+ token_type_ids=token_type_ids,
676
+ )
677
+ prompts = self.get_prompt(batch_size=batch_size * num_choices)
678
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
679
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device)
680
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
681
+
682
+ outputs = self.roberta(
683
+ attention_mask=attention_mask,
684
+ head_mask=head_mask,
685
+ inputs_embeds=inputs_embeds,
686
+ output_attentions=output_attentions,
687
+ output_hidden_states=output_hidden_states,
688
+ return_dict=return_dict,
689
+ )
690
+ pooled_output = outputs[1]
691
+
692
+ pooled_output = self.dropout(pooled_output)
693
+ logits = self.classifier(pooled_output)
694
+ reshaped_logits = logits.view(-1, num_choices)
695
+
696
+ loss = None
697
+ if labels is not None:
698
+ loss_fct = CrossEntropyLoss()
699
+ loss = loss_fct(reshaped_logits, labels)
700
+
701
+ if not return_dict:
702
+ output = (reshaped_logits,) + outputs[2:]
703
+ return ((loss,) + output) if loss is not None else output
704
+
705
+ return MultipleChoiceModelOutput(
706
+ loss=loss,
707
+ logits=reshaped_logits,
708
+ hidden_states=outputs.hidden_states,
709
+ attentions=outputs.attentions,
710
+ )
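The multiple-choice heads above all follow the same pattern: choices are folded into the batch dimension before the encoder runs, and the per-choice scores are folded back out before the loss. A minimal sketch of that reshaping (not part of this commit; tensor sizes are placeholders chosen for illustration):

    import torch
    from torch.nn import CrossEntropyLoss

    batch_size, num_choices, seq_len, hidden = 2, 4, 16, 8
    input_ids = torch.randint(0, 100, (batch_size, num_choices, seq_len))
    labels = torch.tensor([1, 3])                      # index of the correct choice per example

    # Fold choices into the batch dimension, as the forward() methods above do.
    flat_input_ids = input_ids.view(-1, seq_len)       # (batch*choices, seq_len)

    # Stand-in for the encoder's pooled output and the 1-unit classifier head.
    pooled_output = torch.randn(batch_size * num_choices, hidden)
    classifier = torch.nn.Linear(hidden, 1)
    logits = classifier(pooled_output)                 # (batch*choices, 1)

    reshaped_logits = logits.view(-1, num_choices)     # (batch, choices)
    loss = CrossEntropyLoss()(reshaped_logits, labels)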
soft_prompt/model/prefix_encoder.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+
3
+
4
+ class PrefixEncoder(torch.nn.Module):
5
+ r'''
6
+ The torch.nn model to encode the prefix
7
+
8
+ Input shape: (batch-size, prefix-length)
9
+
10
+ Output shape: (batch-size, prefix-length, 2*layers*hidden)
11
+ '''
12
+ def __init__(self, config):
13
+ super().__init__()
14
+ self.prefix_projection = config.prefix_projection
15
+ if self.prefix_projection:
16
+ # Use a two-layer MLP to encode the prefix
17
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
18
+ self.trans = torch.nn.Sequential(
19
+ torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
20
+ torch.nn.Tanh(),
21
+ torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
22
+ )
23
+ else:
24
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)
25
+
26
+ def forward(self, prefix: torch.Tensor):
27
+ device = next(self.embedding.parameters()).device
28
+ if self.prefix_projection:
29
+ prefix_tokens = self.embedding(prefix.to(device))
30
+ past_key_values = self.trans(prefix_tokens)
31
+ else:
32
+ past_key_values = self.embedding(prefix.to(device))
33
+ return past_key_values
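A minimal usage sketch for PrefixEncoder (not part of this commit), showing the output shape and the per-layer key/value split that the *Prefix* models above perform in get_prompt(); the config values below are placeholders:

    import torch
    from types import SimpleNamespace
    from model.prefix_encoder import PrefixEncoder  # assumes the module path used above

    config = SimpleNamespace(
        prefix_projection=False,   # embedding-only variant, no MLP reparameterization
        pre_seq_len=8,
        hidden_size=768,
        prefix_hidden_size=512,
        num_hidden_layers=12,
    )
    encoder = PrefixEncoder(config)

    batch_size, n_head = 4, 12
    n_embd = config.hidden_size // n_head
    prefix_tokens = torch.arange(config.pre_seq_len).unsqueeze(0).expand(batch_size, -1)

    past_key_values = encoder(prefix_tokens)          # (batch, pre_seq_len, 2*layers*hidden)
    past_key_values = past_key_values.view(
        batch_size, config.pre_seq_len, config.num_hidden_layers * 2, n_head, n_embd
    )
    past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
    # tuple of num_hidden_layers tensors, each of shape (2, batch, n_head, pre_seq_len, n_embd)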
soft_prompt/model/question_answering.py ADDED
@@ -0,0 +1,455 @@
1
+ import torch
2
+ import torch.nn
3
+ from torch.nn import CrossEntropyLoss
4
+ from transformers import BertPreTrainedModel, BertModel, RobertaPreTrainedModel, RobertaModel
5
+ from transformers.modeling_outputs import QuestionAnsweringModelOutput
6
+
7
+ from model.prefix_encoder import PrefixEncoder
8
+ from model.deberta import DebertaPreTrainedModel, DebertaModel
9
+
10
+ class BertForQuestionAnswering(BertPreTrainedModel):
11
+
12
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
13
+
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+ self.num_labels = config.num_labels
17
+
18
+ self.bert = BertModel(config, add_pooling_layer=False)
19
+ self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
20
+
21
+ for param in self.bert.parameters():
22
+ param.requires_grad = False
23
+
24
+ self.init_weights()
25
+
26
+ def forward(
27
+ self,
28
+ input_ids=None,
29
+ attention_mask=None,
30
+ token_type_ids=None,
31
+ position_ids=None,
32
+ head_mask=None,
33
+ inputs_embeds=None,
34
+ start_positions=None,
35
+ end_positions=None,
36
+ output_attentions=None,
37
+ output_hidden_states=None,
38
+ return_dict=None,
39
+ ):
40
+ r"""
41
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
42
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
43
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
44
+ sequence are not taken into account for computing the loss.
45
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
46
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
47
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
48
+ sequence are not taken into account for computing the loss.
49
+ """
50
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
51
+
52
+ outputs = self.bert(
53
+ input_ids,
54
+ attention_mask=attention_mask,
55
+ token_type_ids=token_type_ids,
56
+ position_ids=position_ids,
57
+ head_mask=head_mask,
58
+ inputs_embeds=inputs_embeds,
59
+ output_attentions=output_attentions,
60
+ output_hidden_states=output_hidden_states,
61
+ return_dict=return_dict,
62
+ )
63
+
64
+ sequence_output = outputs[0]
65
+
66
+ logits = self.qa_outputs(sequence_output)
67
+ start_logits, end_logits = logits.split(1, dim=-1)
68
+ start_logits = start_logits.squeeze(-1).contiguous()
69
+ end_logits = end_logits.squeeze(-1).contiguous()
70
+
71
+ total_loss = None
72
+ if start_positions is not None and end_positions is not None:
73
+ # If we are on multi-GPU, split add a dimension
74
+ if len(start_positions.size()) > 1:
75
+ start_positions = start_positions.squeeze(-1)
76
+ if len(end_positions.size()) > 1:
77
+ end_positions = end_positions.squeeze(-1)
78
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
79
+ ignored_index = start_logits.size(1)
80
+ start_positions = start_positions.clamp(0, ignored_index)
81
+ end_positions = end_positions.clamp(0, ignored_index)
82
+
83
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
84
+ start_loss = loss_fct(start_logits, start_positions)
85
+ end_loss = loss_fct(end_logits, end_positions)
86
+ total_loss = (start_loss + end_loss) / 2
87
+
88
+ if not return_dict:
89
+ output = (start_logits, end_logits) + outputs[2:]
90
+ return ((total_loss,) + output) if total_loss is not None else output
91
+
92
+ return QuestionAnsweringModelOutput(
93
+ loss=total_loss,
94
+ start_logits=start_logits,
95
+ end_logits=end_logits,
96
+ hidden_states=outputs.hidden_states,
97
+ attentions=outputs.attentions,
98
+ )
99
+
100
+
101
+ class BertPrefixForQuestionAnswering(BertPreTrainedModel):
102
+ def __init__(self, config):
103
+ super().__init__(config)
104
+ self.num_labels = config.num_labels
105
+
106
+ self.pre_seq_len = config.pre_seq_len
107
+ self.n_layer = config.num_hidden_layers
108
+ self.n_head = config.num_attention_heads
109
+ self.n_embd = config.hidden_size // config.num_attention_heads
110
+
111
+ self.bert = BertModel(config, add_pooling_layer=False)
112
+ self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
113
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
114
+ self.prefix_encoder = PrefixEncoder(config)
115
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
116
+
117
+ for param in self.bert.parameters():
118
+ param.requires_grad = False
119
+
120
+ self.init_weights()
121
+
122
+ def get_prompt(self, batch_size):
123
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
124
+ past_key_values = self.prefix_encoder(prefix_tokens)
125
+ bsz, seqlen, _ = past_key_values.shape
126
+ past_key_values = past_key_values.view(
127
+ bsz,
128
+ seqlen,
129
+ self.n_layer * 2,
130
+ self.n_head,
131
+ self.n_embd
132
+ )
133
+ past_key_values = self.dropout(past_key_values)
134
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
135
+ return past_key_values
136
+
137
+ def forward(
138
+ self,
139
+ input_ids=None,
140
+ attention_mask=None,
141
+ token_type_ids=None,
142
+ position_ids=None,
143
+ head_mask=None,
144
+ inputs_embeds=None,
145
+ start_positions=None,
146
+ end_positions=None,
147
+ output_attentions=None,
148
+ output_hidden_states=None,
149
+ return_dict=None,
150
+ ):
151
+ r"""
152
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
153
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
154
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
155
+ sequence are not taken into account for computing the loss.
156
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
157
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
158
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
159
+ sequence are not taken into account for computing the loss.
160
+ """
161
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
162
+
163
+ batch_size = input_ids.shape[0]
164
+ past_key_values = self.get_prompt(batch_size=batch_size)
165
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
166
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
167
+
168
+ outputs = self.bert(
169
+ input_ids,
170
+ attention_mask=attention_mask,
171
+ token_type_ids=token_type_ids,
172
+ position_ids=position_ids,
173
+ head_mask=head_mask,
174
+ inputs_embeds=inputs_embeds,
175
+ output_attentions=output_attentions,
176
+ output_hidden_states=output_hidden_states,
177
+ return_dict=return_dict,
178
+ past_key_values=past_key_values,
179
+ )
180
+
181
+ sequence_output = outputs[0]
182
+
183
+ logits = self.qa_outputs(sequence_output)
184
+ start_logits, end_logits = logits.split(1, dim=-1)
185
+ start_logits = start_logits.squeeze(-1).contiguous()
186
+ end_logits = end_logits.squeeze(-1).contiguous()
187
+
188
+ total_loss = None
189
+ if start_positions is not None and end_positions is not None:
190
+ # If we are on multi-GPU, split add a dimension
191
+ if len(start_positions.size()) > 1:
192
+ start_positions = start_positions.squeeze(-1)
193
+ if len(end_positions.size()) > 1:
194
+ end_positions = end_positions.squeeze(-1)
195
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
196
+ ignored_index = start_logits.size(1)
197
+ start_positions = start_positions.clamp(0, ignored_index)
198
+ end_positions = end_positions.clamp(0, ignored_index)
199
+
200
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
201
+ start_loss = loss_fct(start_logits, start_positions)
202
+ end_loss = loss_fct(end_logits, end_positions)
203
+ total_loss = (start_loss + end_loss) / 2
204
+
205
+ if not return_dict:
206
+ output = (start_logits, end_logits) + outputs[2:]
207
+ return ((total_loss,) + output) if total_loss is not None else output
208
+
209
+ return QuestionAnsweringModelOutput(
210
+ loss=total_loss,
211
+ start_logits=start_logits,
212
+ end_logits=end_logits,
213
+ hidden_states=outputs.hidden_states,
214
+ attentions=outputs.attentions,
215
+ )
216
+
217
+ class RobertaPrefixModelForQuestionAnswering(RobertaPreTrainedModel):
218
+ def __init__(self, config):
219
+ super().__init__(config)
220
+ self.num_labels = config.num_labels
221
+
222
+ self.pre_seq_len = config.pre_seq_len
223
+ self.n_layer = config.num_hidden_layers
224
+ self.n_head = config.num_attention_heads
225
+ self.n_embd = config.hidden_size // config.num_attention_heads
226
+
227
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
228
+ self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
229
+
230
+ self.init_weights()
231
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
232
+ self.prefix_encoder = PrefixEncoder(config)
233
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
234
+
235
+ for param in self.roberta.parameters():
236
+ param.requires_grad = False
237
+
238
+ def get_prompt(self, batch_size):
239
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
240
+ past_key_values = self.prefix_encoder(prefix_tokens)
241
+ bsz, seqlen, _ = past_key_values.shape
242
+ past_key_values = past_key_values.view(
243
+ bsz,
244
+ seqlen,
245
+ self.n_layer * 2,
246
+ self.n_head,
247
+ self.n_embd
248
+ )
249
+ past_key_values = self.dropout(past_key_values)
250
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
251
+ return past_key_values
252
+
253
+ def forward(
254
+ self,
255
+ input_ids=None,
256
+ attention_mask=None,
257
+ token_type_ids=None,
258
+ position_ids=None,
259
+ head_mask=None,
260
+ inputs_embeds=None,
261
+ start_positions=None,
262
+ end_positions=None,
263
+ output_attentions=None,
264
+ output_hidden_states=None,
265
+ return_dict=None,
266
+ ):
267
+ r"""
268
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
269
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
270
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
271
+ sequence are not taken into account for computing the loss.
272
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
273
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
274
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
275
+ sequence are not taken into account for computing the loss.
276
+ """
277
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
278
+
279
+ batch_size = input_ids.shape[0]
280
+ past_key_values = self.get_prompt(batch_size=batch_size)
281
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
282
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
283
+
284
+ outputs = self.roberta(
285
+ input_ids,
286
+ attention_mask=attention_mask,
287
+ token_type_ids=token_type_ids,
288
+ position_ids=position_ids,
289
+ head_mask=head_mask,
290
+ inputs_embeds=inputs_embeds,
291
+ output_attentions=output_attentions,
292
+ output_hidden_states=output_hidden_states,
293
+ return_dict=return_dict,
294
+ past_key_values=past_key_values,
295
+ )
296
+
297
+ sequence_output = outputs[0]
298
+
299
+ logits = self.qa_outputs(sequence_output)
300
+ start_logits, end_logits = logits.split(1, dim=-1)
301
+ start_logits = start_logits.squeeze(-1).contiguous()
302
+ end_logits = end_logits.squeeze(-1).contiguous()
303
+
304
+ total_loss = None
305
+ if start_positions is not None and end_positions is not None:
306
+ # If we are on multi-GPU, split add a dimension
307
+ if len(start_positions.size()) > 1:
308
+ start_positions = start_positions.squeeze(-1)
309
+ if len(end_positions.size()) > 1:
310
+ end_positions = end_positions.squeeze(-1)
311
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
312
+ ignored_index = start_logits.size(1)
313
+ start_positions = start_positions.clamp(0, ignored_index)
314
+ end_positions = end_positions.clamp(0, ignored_index)
315
+
316
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
317
+ start_loss = loss_fct(start_logits, start_positions)
318
+ end_loss = loss_fct(end_logits, end_positions)
319
+ total_loss = (start_loss + end_loss) / 2
320
+
321
+ if not return_dict:
322
+ output = (start_logits, end_logits) + outputs[2:]
323
+ return ((total_loss,) + output) if total_loss is not None else output
324
+
325
+ return QuestionAnsweringModelOutput(
326
+ loss=total_loss,
327
+ start_logits=start_logits,
328
+ end_logits=end_logits,
329
+ hidden_states=outputs.hidden_states,
330
+ attentions=outputs.attentions,
331
+ )
332
+
333
+ class DebertaPrefixModelForQuestionAnswering(DebertaPreTrainedModel):
334
+ def __init__(self, config):
335
+ super().__init__(config)
336
+ self.num_labels = config.num_labels
337
+ self.deberta = DebertaModel(config)
338
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
339
+ self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
340
+ self.init_weights()
341
+
342
+ for param in self.deberta.parameters():
343
+ param.requires_grad = False
344
+
345
+ self.pre_seq_len = config.pre_seq_len
346
+ self.n_layer = config.num_hidden_layers
347
+ self.n_head = config.num_attention_heads
348
+ self.n_embd = config.hidden_size // config.num_attention_heads
349
+
350
+ # Use a two layered MLP to encode the prefix
351
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
352
+ self.prefix_encoder = PrefixEncoder(config)
353
+
354
+ deberta_param = 0
355
+ for name, param in self.deberta.named_parameters():
356
+ deberta_param += param.numel()
357
+ all_param = 0
358
+ for name, param in self.named_parameters():
359
+ all_param += param.numel()
360
+ total_param = all_param - deberta_param
361
+ print('total param is {}'.format(total_param)) # 9860105
362
+
363
+ def get_prompt(self, batch_size):
364
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
365
+ past_key_values = self.prefix_encoder(prefix_tokens)
366
+ # bsz, seqlen, _ = past_key_values.shape
367
+ past_key_values = past_key_values.view(
368
+ batch_size,
369
+ self.pre_seq_len,
370
+ self.n_layer * 2,
371
+ self.n_head,
372
+ self.n_embd
373
+ )
374
+ past_key_values = self.dropout(past_key_values)
375
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
376
+ return past_key_values
377
+
378
+ def forward(
379
+ self,
380
+ input_ids=None,
381
+ attention_mask=None,
382
+ token_type_ids=None,
383
+ position_ids=None,
384
+ # head_mask=None,
385
+ inputs_embeds=None,
386
+ start_positions=None,
387
+ end_positions=None,
388
+ output_attentions=None,
389
+ output_hidden_states=None,
390
+ return_dict=None,
391
+ ):
392
+ r"""
393
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
394
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
395
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
396
+ sequence are not taken into account for computing the loss.
397
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
398
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
399
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
400
+ sequence are not taken into account for computing the loss.
401
+ """
402
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
403
+
404
+ batch_size = input_ids.shape[0]
405
+ past_key_values = self.get_prompt(batch_size=batch_size)
406
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
407
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
408
+
409
+ outputs = self.deberta(
410
+ input_ids,
411
+ attention_mask=attention_mask,
412
+ token_type_ids=token_type_ids,
413
+ position_ids=position_ids,
414
+ inputs_embeds=inputs_embeds,
415
+ output_attentions=output_attentions,
416
+ output_hidden_states=output_hidden_states,
417
+ return_dict=return_dict,
418
+ past_key_values=past_key_values,
419
+ )
420
+
421
+ sequence_output = outputs[0]
422
+
423
+ logits = self.qa_outputs(sequence_output)
424
+ start_logits, end_logits = logits.split(1, dim=-1)
425
+ start_logits = start_logits.squeeze(-1).contiguous()
426
+ end_logits = end_logits.squeeze(-1).contiguous()
427
+
428
+ total_loss = None
429
+ if start_positions is not None and end_positions is not None:
430
+ # If we are on multi-GPU, split add a dimension
431
+ if len(start_positions.size()) > 1:
432
+ start_positions = start_positions.squeeze(-1)
433
+ if len(end_positions.size()) > 1:
434
+ end_positions = end_positions.squeeze(-1)
435
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
436
+ ignored_index = start_logits.size(1)
437
+ start_positions = start_positions.clamp(0, ignored_index)
438
+ end_positions = end_positions.clamp(0, ignored_index)
439
+
440
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
441
+ start_loss = loss_fct(start_logits, start_positions)
442
+ end_loss = loss_fct(end_logits, end_positions)
443
+ total_loss = (start_loss + end_loss) / 2
444
+
445
+ if not return_dict:
446
+ output = (start_logits, end_logits) + outputs[2:]
447
+ return ((total_loss,) + output) if total_loss is not None else output
448
+
449
+ return QuestionAnsweringModelOutput(
450
+ loss=total_loss,
451
+ start_logits=start_logits,
452
+ end_logits=end_logits,
453
+ hidden_states=outputs.hidden_states,
454
+ attentions=outputs.attentions,
455
+ )
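The span loss in these question-answering heads averages separate cross-entropies over the start and end indices, with out-of-range gold positions clamped to the sequence length and then ignored. A small numeric sketch (not part of this commit; the values are made up):

    import torch
    from torch.nn import CrossEntropyLoss

    batch_size, seq_len = 2, 10
    start_logits = torch.randn(batch_size, seq_len)
    end_logits = torch.randn(batch_size, seq_len)
    start_positions = torch.tensor([3, 50])   # second answer lies outside the model input
    end_positions = torch.tensor([5, 60])

    ignored_index = start_logits.size(1)                        # = seq_len
    start_positions = start_positions.clamp(0, ignored_index)   # out-of-range -> ignored_index
    end_positions = end_positions.clamp(0, ignored_index)

    loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
    total_loss = (loss_fct(start_logits, start_positions)
                  + loss_fct(end_logits, end_positions)) / 2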
soft_prompt/model/roberta.py ADDED
@@ -0,0 +1,1588 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch RoBERTa model."""
17
+
18
+ import math
19
+ from typing import List, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.utils.checkpoint
23
+ from torch import nn
24
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
25
+
26
+ from ...activations import ACT2FN, gelu
27
+ from ...modeling_outputs import (
28
+ BaseModelOutputWithPastAndCrossAttentions,
29
+ BaseModelOutputWithPoolingAndCrossAttentions,
30
+ CausalLMOutputWithCrossAttentions,
31
+ MaskedLMOutput,
32
+ MultipleChoiceModelOutput,
33
+ QuestionAnsweringModelOutput,
34
+ SequenceClassifierOutput,
35
+ TokenClassifierOutput,
36
+ )
37
+ from ...modeling_utils import PreTrainedModel
38
+ from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
39
+ from ...utils import (
40
+ add_code_sample_docstrings,
41
+ add_start_docstrings,
42
+ add_start_docstrings_to_model_forward,
43
+ logging,
44
+ replace_return_docstrings,
45
+ )
46
+ from .configuration_roberta import RobertaConfig
47
+
48
+
49
+ logger = logging.get_logger(__name__)
50
+
51
+ _CHECKPOINT_FOR_DOC = "roberta-base"
52
+ _CONFIG_FOR_DOC = "RobertaConfig"
53
+
54
+ ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
55
+ "roberta-base",
56
+ "roberta-large",
57
+ "roberta-large-mnli",
58
+ "distilroberta-base",
59
+ "roberta-base-openai-detector",
60
+ "roberta-large-openai-detector",
61
+ # See all RoBERTa models at https://huggingface.co/models?filter=roberta
62
+ ]
63
+
64
+
65
+ class RobertaEmbeddings(nn.Module):
66
+ """
67
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
68
+ """
69
+
70
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
71
+ def __init__(self, config):
72
+ super().__init__()
73
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
74
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
75
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
76
+
77
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
78
+ # any TensorFlow checkpoint file
79
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
80
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
81
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
82
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
83
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
84
+ self.register_buffer(
85
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
86
+ )
87
+
88
+ # End copy
89
+ self.padding_idx = config.pad_token_id
90
+ self.position_embeddings = nn.Embedding(
91
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
92
+ )
93
+
94
+ def forward(
95
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
96
+ ):
97
+ if position_ids is None:
98
+ if input_ids is not None:
99
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
100
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
101
+ else:
102
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
103
+
104
+ if input_ids is not None:
105
+ input_shape = input_ids.size()
106
+ else:
107
+ input_shape = inputs_embeds.size()[:-1]
108
+
109
+ seq_length = input_shape[1]
110
+
111
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
112
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
113
+ # issue #5664
114
+ if token_type_ids is None:
115
+ if hasattr(self, "token_type_ids"):
116
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
117
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
118
+ token_type_ids = buffered_token_type_ids_expanded
119
+ else:
120
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
121
+
122
+ if inputs_embeds is None:
123
+ inputs_embeds = self.word_embeddings(input_ids)
124
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
125
+
126
+ embeddings = inputs_embeds + token_type_embeddings
127
+ if self.position_embedding_type == "absolute":
128
+ position_embeddings = self.position_embeddings(position_ids)
129
+ embeddings += position_embeddings
130
+ embeddings = self.LayerNorm(embeddings)
131
+ embeddings = self.dropout(embeddings)
132
+ return embeddings
133
+
134
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
135
+ """
136
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
137
+
138
+ Args:
139
+ inputs_embeds: torch.Tensor
140
+
141
+ Returns: torch.Tensor
142
+ """
143
+ input_shape = inputs_embeds.size()[:-1]
144
+ sequence_length = input_shape[1]
145
+
146
+ position_ids = torch.arange(
147
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
148
+ )
149
+ return position_ids.unsqueeze(0).expand(input_shape)
150
+
151
+
152
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
153
+ class RobertaSelfAttention(nn.Module):
154
+ def __init__(self, config, position_embedding_type=None):
155
+ super().__init__()
156
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
157
+ raise ValueError(
158
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
159
+ f"heads ({config.num_attention_heads})"
160
+ )
161
+
162
+ self.num_attention_heads = config.num_attention_heads
163
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
164
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
165
+
166
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
167
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
168
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
169
+
170
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
171
+ self.position_embedding_type = position_embedding_type or getattr(
172
+ config, "position_embedding_type", "absolute"
173
+ )
174
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
175
+ self.max_position_embeddings = config.max_position_embeddings
176
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
177
+
178
+ self.is_decoder = config.is_decoder
179
+
180
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
181
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
182
+ x = x.view(new_x_shape)
183
+ return x.permute(0, 2, 1, 3)
184
+
185
+ def forward(
186
+ self,
187
+ hidden_states: torch.Tensor,
188
+ attention_mask: Optional[torch.FloatTensor] = None,
189
+ head_mask: Optional[torch.FloatTensor] = None,
190
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
191
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
192
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
193
+ output_attentions: Optional[bool] = False,
194
+ ) -> Tuple[torch.Tensor]:
195
+ mixed_query_layer = self.query(hidden_states)
196
+
197
+ # If this is instantiated as a cross-attention module, the keys
198
+ # and values come from an encoder; the attention mask needs to be
199
+ # such that the encoder's padding tokens are not attended to.
200
+ is_cross_attention = encoder_hidden_states is not None
201
+
202
+ if is_cross_attention and past_key_value is not None:
203
+ # reuse k,v, cross_attentions
204
+ key_layer = past_key_value[0]
205
+ value_layer = past_key_value[1]
206
+ attention_mask = encoder_attention_mask
207
+ elif is_cross_attention:
208
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
209
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
210
+ attention_mask = encoder_attention_mask
211
+ elif past_key_value is not None:
212
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
213
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
214
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
215
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
216
+ else:
217
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
218
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
219
+
220
+ query_layer = self.transpose_for_scores(mixed_query_layer)
221
+
222
+ use_cache = past_key_value is not None
223
+ if self.is_decoder:
224
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
225
+ # Further calls to cross_attention layer can then reuse all cross-attention
226
+ # key/value_states (first "if" case)
227
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
228
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
229
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
230
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
231
+ past_key_value = (key_layer, value_layer)
232
+
233
+ # Take the dot product between "query" and "key" to get the raw attention scores.
234
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
235
+
236
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
237
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
238
+ if use_cache:
239
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
240
+ -1, 1
241
+ )
242
+ else:
243
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
244
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
245
+ distance = position_ids_l - position_ids_r
246
+
247
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
248
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
249
+
250
+ if self.position_embedding_type == "relative_key":
251
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
252
+ attention_scores = attention_scores + relative_position_scores
253
+ elif self.position_embedding_type == "relative_key_query":
254
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
255
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
256
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
257
+
258
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
259
+ if attention_mask is not None:
260
+ # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
261
+ attention_scores = attention_scores + attention_mask
262
+
263
+ # Normalize the attention scores to probabilities.
264
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
265
+
266
+ # This is actually dropping out entire tokens to attend to, which might
267
+ # seem a bit unusual, but is taken from the original Transformer paper.
268
+ attention_probs = self.dropout(attention_probs)
269
+
270
+ # Mask heads if we want to
271
+ if head_mask is not None:
272
+ attention_probs = attention_probs * head_mask
273
+
274
+ context_layer = torch.matmul(attention_probs, value_layer)
275
+
276
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
277
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
278
+ context_layer = context_layer.view(new_context_layer_shape)
279
+
280
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
281
+
282
+ if self.is_decoder:
283
+ outputs = outputs + (past_key_value,)
284
+ return outputs
285
+
286
+
287
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
288
+ class RobertaSelfOutput(nn.Module):
289
+ def __init__(self, config):
290
+ super().__init__()
291
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
292
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
293
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
294
+
295
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
296
+ hidden_states = self.dense(hidden_states)
297
+ hidden_states = self.dropout(hidden_states)
298
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
299
+ return hidden_states
300
+
301
+
302
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
303
+ class RobertaAttention(nn.Module):
304
+ def __init__(self, config, position_embedding_type=None):
305
+ super().__init__()
306
+ self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
307
+ self.output = RobertaSelfOutput(config)
308
+ self.pruned_heads = set()
309
+
310
+ def prune_heads(self, heads):
311
+ if len(heads) == 0:
312
+ return
313
+ heads, index = find_pruneable_heads_and_indices(
314
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
315
+ )
316
+
317
+ # Prune linear layers
318
+ self.self.query = prune_linear_layer(self.self.query, index)
319
+ self.self.key = prune_linear_layer(self.self.key, index)
320
+ self.self.value = prune_linear_layer(self.self.value, index)
321
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
322
+
323
+ # Update hyper params and store pruned heads
324
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
325
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
326
+ self.pruned_heads = self.pruned_heads.union(heads)
327
+
328
+ def forward(
329
+ self,
330
+ hidden_states: torch.Tensor,
331
+ attention_mask: Optional[torch.FloatTensor] = None,
332
+ head_mask: Optional[torch.FloatTensor] = None,
333
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
334
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
335
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
336
+ output_attentions: Optional[bool] = False,
337
+ ) -> Tuple[torch.Tensor]:
338
+ self_outputs = self.self(
339
+ hidden_states,
340
+ attention_mask,
341
+ head_mask,
342
+ encoder_hidden_states,
343
+ encoder_attention_mask,
344
+ past_key_value,
345
+ output_attentions,
346
+ )
347
+ attention_output = self.output(self_outputs[0], hidden_states)
348
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
349
+ return outputs
350
+
351
+
352
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate
353
+ class RobertaIntermediate(nn.Module):
354
+ def __init__(self, config):
355
+ super().__init__()
356
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
357
+ if isinstance(config.hidden_act, str):
358
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
359
+ else:
360
+ self.intermediate_act_fn = config.hidden_act
361
+
362
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
363
+ hidden_states = self.dense(hidden_states)
364
+ hidden_states = self.intermediate_act_fn(hidden_states)
365
+ return hidden_states
366
+
367
+
368
+ # Copied from transformers.models.bert.modeling_bert.BertOutput
369
+ class RobertaOutput(nn.Module):
370
+ def __init__(self, config):
371
+ super().__init__()
372
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
373
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
374
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
375
+
376
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
377
+ hidden_states = self.dense(hidden_states)
378
+ hidden_states = self.dropout(hidden_states)
379
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
380
+ return hidden_states
381
+
382
+
383
+ # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
384
+ class RobertaLayer(nn.Module):
385
+ def __init__(self, config):
386
+ super().__init__()
387
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
388
+ self.seq_len_dim = 1
389
+ self.attention = RobertaAttention(config)
390
+ self.is_decoder = config.is_decoder
391
+ self.add_cross_attention = config.add_cross_attention
392
+ if self.add_cross_attention:
393
+ if not self.is_decoder:
394
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
395
+ self.crossattention = RobertaAttention(config, position_embedding_type="absolute")
396
+ self.intermediate = RobertaIntermediate(config)
397
+ self.output = RobertaOutput(config)
398
+
399
+ def forward(
400
+ self,
401
+ hidden_states: torch.Tensor,
402
+ attention_mask: Optional[torch.FloatTensor] = None,
403
+ head_mask: Optional[torch.FloatTensor] = None,
404
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
405
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
406
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
407
+ output_attentions: Optional[bool] = False,
408
+ ) -> Tuple[torch.Tensor]:
409
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
410
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
411
+ self_attention_outputs = self.attention(
412
+ hidden_states,
413
+ attention_mask,
414
+ head_mask,
415
+ output_attentions=output_attentions,
416
+ past_key_value=self_attn_past_key_value,
417
+ )
418
+ attention_output = self_attention_outputs[0]
419
+
420
+ # if decoder, the last output is tuple of self-attn cache
421
+ if self.is_decoder:
422
+ outputs = self_attention_outputs[1:-1]
423
+ present_key_value = self_attention_outputs[-1]
424
+ else:
425
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
426
+
427
+ cross_attn_present_key_value = None
428
+ if self.is_decoder and encoder_hidden_states is not None:
429
+ if not hasattr(self, "crossattention"):
430
+ raise ValueError(
431
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
432
+ " by setting `config.add_cross_attention=True`"
433
+ )
434
+
435
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
436
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
437
+ cross_attention_outputs = self.crossattention(
438
+ attention_output,
439
+ attention_mask,
440
+ head_mask,
441
+ encoder_hidden_states,
442
+ encoder_attention_mask,
443
+ cross_attn_past_key_value,
444
+ output_attentions,
445
+ )
446
+ attention_output = cross_attention_outputs[0]
447
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
448
+
449
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
450
+ cross_attn_present_key_value = cross_attention_outputs[-1]
451
+ present_key_value = present_key_value + cross_attn_present_key_value
452
+
453
+ layer_output = apply_chunking_to_forward(
454
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
455
+ )
456
+ outputs = (layer_output,) + outputs
457
+
458
+ # if decoder, return the attn key/values as the last output
459
+ if self.is_decoder:
460
+ outputs = outputs + (present_key_value,)
461
+
462
+ return outputs
463
+
464
+ def feed_forward_chunk(self, attention_output):
465
+ intermediate_output = self.intermediate(attention_output)
466
+ layer_output = self.output(intermediate_output, attention_output)
467
+ return layer_output
468
+
469
+
470
+ # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
471
+ class RobertaEncoder(nn.Module):
472
+ def __init__(self, config):
473
+ super().__init__()
474
+ self.config = config
475
+ self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])
476
+ self.gradient_checkpointing = False
477
+
478
+ def forward(
479
+ self,
480
+ hidden_states: torch.Tensor,
481
+ attention_mask: Optional[torch.FloatTensor] = None,
482
+ head_mask: Optional[torch.FloatTensor] = None,
483
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
484
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
485
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
486
+ use_cache: Optional[bool] = None,
487
+ output_attentions: Optional[bool] = False,
488
+ output_hidden_states: Optional[bool] = False,
489
+ return_dict: Optional[bool] = True,
490
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
491
+ all_hidden_states = () if output_hidden_states else None
492
+ all_self_attentions = () if output_attentions else None
493
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
494
+
495
+ if self.gradient_checkpointing and self.training:
496
+ if use_cache:
497
+ logger.warning_once(
498
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
499
+ )
500
+ use_cache = False
501
+
502
+ next_decoder_cache = () if use_cache else None
503
+ for i, layer_module in enumerate(self.layer):
504
+ if output_hidden_states:
505
+ all_hidden_states = all_hidden_states + (hidden_states,)
506
+
507
+ layer_head_mask = head_mask[i] if head_mask is not None else None
508
+ past_key_value = past_key_values[i] if past_key_values is not None else None
509
+
510
+ if self.gradient_checkpointing and self.training:
511
+
512
+ def create_custom_forward(module):
513
+ def custom_forward(*inputs):
514
+ return module(*inputs, past_key_value, output_attentions)
515
+
516
+ return custom_forward
517
+
518
+ layer_outputs = torch.utils.checkpoint.checkpoint(
519
+ create_custom_forward(layer_module),
520
+ hidden_states,
521
+ attention_mask,
522
+ layer_head_mask,
523
+ encoder_hidden_states,
524
+ encoder_attention_mask,
525
+ )
526
+ else:
527
+ layer_outputs = layer_module(
528
+ hidden_states,
529
+ attention_mask,
530
+ layer_head_mask,
531
+ encoder_hidden_states,
532
+ encoder_attention_mask,
533
+ past_key_value,
534
+ output_attentions,
535
+ )
536
+
537
+ hidden_states = layer_outputs[0]
538
+ if use_cache:
539
+ next_decoder_cache += (layer_outputs[-1],)
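+                 # For reference: each entry appended here is the key/value cache returned by one layer
+                 # (self-attention, plus cross-attention when present), so the final tuple holds one
+                 # cache per decoder layer and is returned as `past_key_values`.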
540
+ if output_attentions:
541
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
542
+ if self.config.add_cross_attention:
543
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
544
+
545
+ if output_hidden_states:
546
+ all_hidden_states = all_hidden_states + (hidden_states,)
547
+
548
+ if not return_dict:
549
+ return tuple(
550
+ v
551
+ for v in [
552
+ hidden_states,
553
+ next_decoder_cache,
554
+ all_hidden_states,
555
+ all_self_attentions,
556
+ all_cross_attentions,
557
+ ]
558
+ if v is not None
559
+ )
560
+ return BaseModelOutputWithPastAndCrossAttentions(
561
+ last_hidden_state=hidden_states,
562
+ past_key_values=next_decoder_cache,
563
+ hidden_states=all_hidden_states,
564
+ attentions=all_self_attentions,
565
+ cross_attentions=all_cross_attentions,
566
+ )
567
+
568
+
569
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
570
+ class RobertaPooler(nn.Module):
571
+ def __init__(self, config):
572
+ super().__init__()
573
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
574
+ self.activation = nn.Tanh()
575
+
576
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
577
+ # We "pool" the model by simply taking the hidden state corresponding
578
+ # to the first token.
579
+ first_token_tensor = hidden_states[:, 0]
580
+ pooled_output = self.dense(first_token_tensor)
581
+ pooled_output = self.activation(pooled_output)
582
+ return pooled_output
583
+
584
+
585
+ class RobertaPreTrainedModel(PreTrainedModel):
586
+ """
587
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
588
+ models.
589
+ """
590
+
591
+ config_class = RobertaConfig
592
+ base_model_prefix = "roberta"
593
+ supports_gradient_checkpointing = True
594
+ _no_split_modules = []
595
+
596
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
597
+ def _init_weights(self, module):
598
+ """Initialize the weights"""
599
+ if isinstance(module, nn.Linear):
600
+ # Slightly different from the TF version which uses truncated_normal for initialization
601
+ # cf https://github.com/pytorch/pytorch/pull/5617
602
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
603
+ if module.bias is not None:
604
+ module.bias.data.zero_()
605
+ elif isinstance(module, nn.Embedding):
606
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
607
+ if module.padding_idx is not None:
608
+ module.weight.data[module.padding_idx].zero_()
609
+ elif isinstance(module, nn.LayerNorm):
610
+ module.bias.data.zero_()
611
+ module.weight.data.fill_(1.0)
612
+
613
+ def _set_gradient_checkpointing(self, module, value=False):
614
+ if isinstance(module, RobertaEncoder):
615
+ module.gradient_checkpointing = value
616
+
617
+ def update_keys_to_ignore(self, config, del_keys_to_ignore):
618
+ """Remove some keys from ignore list"""
619
+ if not config.tie_word_embeddings:
620
+ # must make a new list, or the class variable gets modified!
621
+ self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
622
+ self._keys_to_ignore_on_load_missing = [
623
+ k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
624
+ ]
625
+
626
+
627
+ ROBERTA_START_DOCSTRING = r"""
628
+
629
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
630
+     library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
631
+ etc.)
632
+
633
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
634
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
635
+ and behavior.
636
+
637
+ Parameters:
638
+ config ([`RobertaConfig`]): Model configuration class with all the parameters of the
639
+ model. Initializing with a config file does not load the weights associated with the model, only the
640
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
641
+ """
642
+
643
+ ROBERTA_INPUTS_DOCSTRING = r"""
644
+ Args:
645
+ input_ids (`torch.LongTensor` of shape `({0})`):
646
+ Indices of input sequence tokens in the vocabulary.
647
+
648
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
649
+ [`PreTrainedTokenizer.__call__`] for details.
650
+
651
+ [What are input IDs?](../glossary#input-ids)
652
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
653
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
654
+
655
+ - 1 for tokens that are **not masked**,
656
+ - 0 for tokens that are **masked**.
657
+
658
+ [What are attention masks?](../glossary#attention-mask)
659
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
660
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
661
+
662
+ - 0 corresponds to a *sentence A* token,
663
+ - 1 corresponds to a *sentence B* token.
664
+             This parameter can only be used when the model is initialized with the `type_vocab_size` parameter set to a value
665
+             >= 2. All values in this tensor should be less than `type_vocab_size`.
666
+
667
+ [What are token type IDs?](../glossary#token-type-ids)
668
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
669
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
670
+ config.max_position_embeddings - 1]`.
671
+
672
+ [What are position IDs?](../glossary#position-ids)
673
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
674
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
675
+
676
+ - 1 indicates the head is **not masked**,
677
+ - 0 indicates the head is **masked**.
678
+
679
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
680
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
681
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
682
+ model's internal embedding lookup matrix.
683
+ output_attentions (`bool`, *optional*):
684
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
685
+ tensors for more detail.
686
+ output_hidden_states (`bool`, *optional*):
687
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
688
+ more detail.
689
+ return_dict (`bool`, *optional*):
690
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
691
+ """
692
+
693
+
694
+ @add_start_docstrings(
695
+ "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
696
+ ROBERTA_START_DOCSTRING,
697
+ )
698
+ class RobertaModel(RobertaPreTrainedModel):
699
+ """
700
+
701
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
702
+ cross-attention is added between the self-attention layers, following the architecture described in *Attention is
703
+ all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
704
+ Kaiser and Illia Polosukhin.
705
+
706
+     To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
707
+     to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
708
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
709
+
710
+ .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
711
+
712
+ """
713
+
714
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
715
+
716
+ # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
717
+ def __init__(self, config, add_pooling_layer=True):
718
+ super().__init__(config)
719
+ self.config = config
720
+
721
+ self.embeddings = RobertaEmbeddings(config)
722
+ self.encoder = RobertaEncoder(config)
723
+
724
+ self.pooler = RobertaPooler(config) if add_pooling_layer else None
725
+
726
+ # Initialize weights and apply final processing
727
+ self.post_init()
728
+
729
+ def get_input_embeddings(self):
730
+ return self.embeddings.word_embeddings
731
+
732
+ def set_input_embeddings(self, value):
733
+ self.embeddings.word_embeddings = value
734
+
735
+ def _prune_heads(self, heads_to_prune):
736
+ """
737
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
738
+ class PreTrainedModel
739
+ """
740
+ for layer, heads in heads_to_prune.items():
741
+ self.encoder.layer[layer].attention.prune_heads(heads)
742
+
743
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
744
+ @add_code_sample_docstrings(
745
+ checkpoint=_CHECKPOINT_FOR_DOC,
746
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
747
+ config_class=_CONFIG_FOR_DOC,
748
+ )
749
+ # Copied from transformers.models.bert.modeling_bert.BertModel.forward
750
+ def forward(
751
+ self,
752
+ input_ids: Optional[torch.Tensor] = None,
753
+ attention_mask: Optional[torch.Tensor] = None,
754
+ token_type_ids: Optional[torch.Tensor] = None,
755
+ position_ids: Optional[torch.Tensor] = None,
756
+ head_mask: Optional[torch.Tensor] = None,
757
+ inputs_embeds: Optional[torch.Tensor] = None,
758
+ encoder_hidden_states: Optional[torch.Tensor] = None,
759
+ encoder_attention_mask: Optional[torch.Tensor] = None,
760
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
761
+ use_cache: Optional[bool] = None,
762
+ output_attentions: Optional[bool] = None,
763
+ output_hidden_states: Optional[bool] = None,
764
+ return_dict: Optional[bool] = None,
765
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
766
+ r"""
767
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
768
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
769
+ the model is configured as a decoder.
770
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
771
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
772
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
773
+
774
+ - 1 for tokens that are **not masked**,
775
+ - 0 for tokens that are **masked**.
776
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
777
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
778
+
779
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
780
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
781
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
782
+ use_cache (`bool`, *optional*):
783
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
784
+ `past_key_values`).
785
+ """
786
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
787
+ output_hidden_states = (
788
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
789
+ )
790
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
791
+
792
+ if self.config.is_decoder:
793
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
794
+ else:
795
+ use_cache = False
796
+
797
+ if input_ids is not None and inputs_embeds is not None:
798
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
799
+ elif input_ids is not None:
800
+ input_shape = input_ids.size()
801
+ elif inputs_embeds is not None:
802
+ input_shape = inputs_embeds.size()[:-1]
803
+ else:
804
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
805
+
806
+ batch_size, seq_length = input_shape
807
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
808
+
809
+ # past_key_values_length
810
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
811
+
812
+ if attention_mask is None:
813
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
814
+
815
+ if token_type_ids is None:
816
+ if hasattr(self.embeddings, "token_type_ids"):
817
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
818
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
819
+ token_type_ids = buffered_token_type_ids_expanded
820
+ else:
821
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
822
+
823
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
824
+ # ourselves in which case we just need to make it broadcastable to all heads.
825
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
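+         # For reference: the extended mask is additive, with 0.0 at positions to attend to and a large
+         # negative value at masked positions, broadcastable to (batch_size, num_heads, seq_length, seq_length).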
826
+
827
+ # If a 2D or 3D attention mask is provided for the cross-attention
828
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
829
+ if self.config.is_decoder and encoder_hidden_states is not None:
830
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
831
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
832
+ if encoder_attention_mask is None:
833
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
834
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
835
+ else:
836
+ encoder_extended_attention_mask = None
837
+
838
+ # Prepare head mask if needed
839
+ # 1.0 in head_mask indicate we keep the head
840
+ # attention_probs has shape bsz x n_heads x N x N
841
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
842
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
843
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
844
+
845
+ embedding_output = self.embeddings(
846
+ input_ids=input_ids,
847
+ position_ids=position_ids,
848
+ token_type_ids=token_type_ids,
849
+ inputs_embeds=inputs_embeds,
850
+ past_key_values_length=past_key_values_length,
851
+ )
852
+ encoder_outputs = self.encoder(
853
+ embedding_output,
854
+ attention_mask=extended_attention_mask,
855
+ head_mask=head_mask,
856
+ encoder_hidden_states=encoder_hidden_states,
857
+ encoder_attention_mask=encoder_extended_attention_mask,
858
+ past_key_values=past_key_values,
859
+ use_cache=use_cache,
860
+ output_attentions=output_attentions,
861
+ output_hidden_states=output_hidden_states,
862
+ return_dict=return_dict,
863
+ )
864
+ sequence_output = encoder_outputs[0]
865
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
866
+
867
+ if not return_dict:
868
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
869
+
870
+ return BaseModelOutputWithPoolingAndCrossAttentions(
871
+ last_hidden_state=sequence_output,
872
+ pooler_output=pooled_output,
873
+ past_key_values=encoder_outputs.past_key_values,
874
+ hidden_states=encoder_outputs.hidden_states,
875
+ attentions=encoder_outputs.attentions,
876
+ cross_attentions=encoder_outputs.cross_attentions,
877
+ )
878
+
879
+
880
+ @add_start_docstrings(
881
+ """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING
882
+ )
883
+ class RobertaForCausalLM(RobertaPreTrainedModel):
884
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
885
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
886
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
887
+
888
+ def __init__(self, config):
889
+ super().__init__(config)
890
+
891
+ if not config.is_decoder:
892
+             logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True`.")
893
+
894
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
895
+ self.lm_head = RobertaLMHead(config)
896
+
897
+ # The LM head weights require special treatment only when they are tied with the word embeddings
898
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
899
+
900
+ # Initialize weights and apply final processing
901
+ self.post_init()
902
+
903
+ def get_output_embeddings(self):
904
+ return self.lm_head.decoder
905
+
906
+ def set_output_embeddings(self, new_embeddings):
907
+ self.lm_head.decoder = new_embeddings
908
+
909
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
910
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
911
+ def forward(
912
+ self,
913
+ input_ids: Optional[torch.LongTensor] = None,
914
+ attention_mask: Optional[torch.FloatTensor] = None,
915
+ token_type_ids: Optional[torch.LongTensor] = None,
916
+ position_ids: Optional[torch.LongTensor] = None,
917
+ head_mask: Optional[torch.FloatTensor] = None,
918
+ inputs_embeds: Optional[torch.FloatTensor] = None,
919
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
920
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
921
+ labels: Optional[torch.LongTensor] = None,
922
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
923
+ use_cache: Optional[bool] = None,
924
+ output_attentions: Optional[bool] = None,
925
+ output_hidden_states: Optional[bool] = None,
926
+ return_dict: Optional[bool] = None,
927
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
928
+ r"""
929
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
930
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
931
+ the model is configured as a decoder.
932
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
933
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
934
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
935
+
936
+ - 1 for tokens that are **not masked**,
937
+ - 0 for tokens that are **masked**.
938
+
939
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
940
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
941
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
942
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
943
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
944
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
945
+
946
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
947
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
948
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
949
+ use_cache (`bool`, *optional*):
950
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
951
+ `past_key_values`).
952
+
953
+ Returns:
954
+
955
+ Example:
956
+
957
+ ```python
958
+ >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig
959
+ >>> import torch
960
+
961
+ >>> tokenizer = AutoTokenizer.from_pretrained("roberta-base")
962
+ >>> config = AutoConfig.from_pretrained("roberta-base")
963
+ >>> config.is_decoder = True
964
+ >>> model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)
965
+
966
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
967
+ >>> outputs = model(**inputs)
968
+
969
+ >>> prediction_logits = outputs.logits
970
+ ```"""
971
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
972
+ if labels is not None:
973
+ use_cache = False
974
+
975
+ outputs = self.roberta(
976
+ input_ids,
977
+ attention_mask=attention_mask,
978
+ token_type_ids=token_type_ids,
979
+ position_ids=position_ids,
980
+ head_mask=head_mask,
981
+ inputs_embeds=inputs_embeds,
982
+ encoder_hidden_states=encoder_hidden_states,
983
+ encoder_attention_mask=encoder_attention_mask,
984
+ past_key_values=past_key_values,
985
+ use_cache=use_cache,
986
+ output_attentions=output_attentions,
987
+ output_hidden_states=output_hidden_states,
988
+ return_dict=return_dict,
989
+ )
990
+
991
+ sequence_output = outputs[0]
992
+ prediction_scores = self.lm_head(sequence_output)
993
+
994
+ lm_loss = None
995
+ if labels is not None:
996
+ # move labels to correct device to enable model parallelism
997
+ labels = labels.to(prediction_scores.device)
998
+ # we are doing next-token prediction; shift prediction scores and input ids by one
999
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1000
+ labels = labels[:, 1:].contiguous()
1001
+ loss_fct = CrossEntropyLoss()
1002
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1003
+
1004
+ if not return_dict:
1005
+ output = (prediction_scores,) + outputs[2:]
1006
+ return ((lm_loss,) + output) if lm_loss is not None else output
1007
+
1008
+ return CausalLMOutputWithCrossAttentions(
1009
+ loss=lm_loss,
1010
+ logits=prediction_scores,
1011
+ past_key_values=outputs.past_key_values,
1012
+ hidden_states=outputs.hidden_states,
1013
+ attentions=outputs.attentions,
1014
+ cross_attentions=outputs.cross_attentions,
1015
+ )
1016
+
1017
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
1018
+ input_shape = input_ids.shape
1019
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1020
+ if attention_mask is None:
1021
+ attention_mask = input_ids.new_ones(input_shape)
1022
+
1023
+ # cut decoder_input_ids if past is used
1024
+ if past_key_values is not None:
1025
+ input_ids = input_ids[:, -1:]
1026
+
1027
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
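+         # Illustrative: during generation, once `past_key_values` is populated only the most recent token id
+         # is fed to the model, e.g. input_ids of shape (batch_size, seq_len) becomes (batch_size, 1).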
1028
+
1029
+ def _reorder_cache(self, past_key_values, beam_idx):
1030
+ reordered_past = ()
1031
+ for layer_past in past_key_values:
1032
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
1033
+ return reordered_past
1034
+
1035
+
1036
+ @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING)
1037
+ class RobertaForMaskedLM(RobertaPreTrainedModel):
1038
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1039
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
1040
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1041
+
1042
+ def __init__(self, config):
1043
+ super().__init__(config)
1044
+
1045
+ if config.is_decoder:
1046
+ logger.warning(
1047
+ "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
1048
+ "bi-directional self-attention."
1049
+ )
1050
+
1051
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1052
+ self.lm_head = RobertaLMHead(config)
1053
+
1054
+ # The LM head weights require special treatment only when they are tied with the word embeddings
1055
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
1056
+
1057
+ # Initialize weights and apply final processing
1058
+ self.post_init()
1059
+
1060
+ def get_output_embeddings(self):
1061
+ return self.lm_head.decoder
1062
+
1063
+ def set_output_embeddings(self, new_embeddings):
1064
+ self.lm_head.decoder = new_embeddings
1065
+
1066
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1067
+ @add_code_sample_docstrings(
1068
+ checkpoint=_CHECKPOINT_FOR_DOC,
1069
+ output_type=MaskedLMOutput,
1070
+ config_class=_CONFIG_FOR_DOC,
1071
+ mask="<mask>",
1072
+ expected_output="' Paris'",
1073
+ expected_loss=0.1,
1074
+ )
1075
+ def forward(
1076
+ self,
1077
+ input_ids: Optional[torch.LongTensor] = None,
1078
+ attention_mask: Optional[torch.FloatTensor] = None,
1079
+ token_type_ids: Optional[torch.LongTensor] = None,
1080
+ position_ids: Optional[torch.LongTensor] = None,
1081
+ head_mask: Optional[torch.FloatTensor] = None,
1082
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1083
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1084
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1085
+ labels: Optional[torch.LongTensor] = None,
1086
+ output_attentions: Optional[bool] = None,
1087
+ output_hidden_states: Optional[bool] = None,
1088
+ return_dict: Optional[bool] = None,
1089
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
1090
+ r"""
1091
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1092
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1093
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1094
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1095
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
1096
+ Used to hide legacy arguments that have been deprecated.
1097
+ """
1098
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1099
+
1100
+ outputs = self.roberta(
1101
+ input_ids,
1102
+ attention_mask=attention_mask,
1103
+ token_type_ids=token_type_ids,
1104
+ position_ids=position_ids,
1105
+ head_mask=head_mask,
1106
+ inputs_embeds=inputs_embeds,
1107
+ encoder_hidden_states=encoder_hidden_states,
1108
+ encoder_attention_mask=encoder_attention_mask,
1109
+ output_attentions=output_attentions,
1110
+ output_hidden_states=output_hidden_states,
1111
+ return_dict=return_dict,
1112
+ )
1113
+ sequence_output = outputs[0]
1114
+ prediction_scores = self.lm_head(sequence_output)
1115
+
1116
+ masked_lm_loss = None
1117
+ if labels is not None:
1118
+ # move labels to correct device to enable model parallelism
1119
+ labels = labels.to(prediction_scores.device)
1120
+ loss_fct = CrossEntropyLoss()
1121
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1122
+
1123
+ if not return_dict:
1124
+ output = (prediction_scores,) + outputs[2:]
1125
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1126
+
1127
+ return MaskedLMOutput(
1128
+ loss=masked_lm_loss,
1129
+ logits=prediction_scores,
1130
+ hidden_states=outputs.hidden_states,
1131
+ attentions=outputs.attentions,
1132
+ )
1133
+
1134
+
1135
+ class RobertaLMHead(nn.Module):
1136
+ """Roberta Head for masked language modeling."""
1137
+
1138
+ def __init__(self, config):
1139
+ super().__init__()
1140
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1141
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1142
+
1143
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
1144
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1145
+ self.decoder.bias = self.bias
1146
+
1147
+ def forward(self, features, **kwargs):
1148
+ x = self.dense(features)
1149
+ x = gelu(x)
1150
+ x = self.layer_norm(x)
1151
+
1152
+ # project back to size of vocabulary with bias
1153
+ x = self.decoder(x)
1154
+
1155
+ return x
1156
+
1157
+ def _tie_weights(self):
1158
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
1159
+ # For accelerate compatibility and to not break backward compatibility
1160
+ if self.decoder.bias.device.type == "meta":
1161
+ self.decoder.bias = self.bias
1162
+ else:
1163
+ self.bias = self.decoder.bias
1164
+
1165
+
1166
+ @add_start_docstrings(
1167
+ """
1168
+ RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1169
+ pooled output) e.g. for GLUE tasks.
1170
+ """,
1171
+ ROBERTA_START_DOCSTRING,
1172
+ )
1173
+ class RobertaForSequenceClassification(RobertaPreTrainedModel):
1174
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1175
+
1176
+ def __init__(self, config):
1177
+ super().__init__(config)
1178
+ self.num_labels = config.num_labels
1179
+ self.config = config
1180
+
1181
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1182
+ self.classifier = RobertaClassificationHead(config)
1183
+
1184
+ # Initialize weights and apply final processing
1185
+ self.post_init()
1186
+
1187
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1188
+ @add_code_sample_docstrings(
1189
+ checkpoint="cardiffnlp/twitter-roberta-base-emotion",
1190
+ output_type=SequenceClassifierOutput,
1191
+ config_class=_CONFIG_FOR_DOC,
1192
+ expected_output="'optimism'",
1193
+ expected_loss=0.08,
1194
+ )
1195
+ def forward(
1196
+ self,
1197
+ input_ids: Optional[torch.LongTensor] = None,
1198
+ attention_mask: Optional[torch.FloatTensor] = None,
1199
+ token_type_ids: Optional[torch.LongTensor] = None,
1200
+ position_ids: Optional[torch.LongTensor] = None,
1201
+ head_mask: Optional[torch.FloatTensor] = None,
1202
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1203
+ labels: Optional[torch.LongTensor] = None,
1204
+ output_attentions: Optional[bool] = None,
1205
+ output_hidden_states: Optional[bool] = None,
1206
+ return_dict: Optional[bool] = None,
1207
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
1208
+ r"""
1209
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1210
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1211
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1212
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1213
+ """
1214
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1215
+
1216
+ outputs = self.roberta(
1217
+ input_ids,
1218
+ attention_mask=attention_mask,
1219
+ token_type_ids=token_type_ids,
1220
+ position_ids=position_ids,
1221
+ head_mask=head_mask,
1222
+ inputs_embeds=inputs_embeds,
1223
+ output_attentions=output_attentions,
1224
+ output_hidden_states=output_hidden_states,
1225
+ return_dict=return_dict,
1226
+ )
1227
+ sequence_output = outputs[0]
1228
+ logits = self.classifier(sequence_output)
1229
+
1230
+ loss = None
1231
+ if labels is not None:
1232
+ # move labels to correct device to enable model parallelism
1233
+ labels = labels.to(logits.device)
1234
+ if self.config.problem_type is None:
1235
+ if self.num_labels == 1:
1236
+ self.config.problem_type = "regression"
1237
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1238
+ self.config.problem_type = "single_label_classification"
1239
+ else:
1240
+ self.config.problem_type = "multi_label_classification"
1241
+
1242
+ if self.config.problem_type == "regression":
1243
+ loss_fct = MSELoss()
1244
+ if self.num_labels == 1:
1245
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1246
+ else:
1247
+ loss = loss_fct(logits, labels)
1248
+ elif self.config.problem_type == "single_label_classification":
1249
+ loss_fct = CrossEntropyLoss()
1250
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1251
+ elif self.config.problem_type == "multi_label_classification":
1252
+ loss_fct = BCEWithLogitsLoss()
1253
+ loss = loss_fct(logits, labels)
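+             # Illustrative label conventions for the three problem types:
+             #   regression                   -> float labels, shape (batch_size,) when num_labels == 1
+             #   single_label_classification -> int64 class indices, shape (batch_size,)
+             #   multi_label_classification  -> float multi-hot targets, shape (batch_size, num_labels)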
1254
+
1255
+ if not return_dict:
1256
+ output = (logits,) + outputs[2:]
1257
+ return ((loss,) + output) if loss is not None else output
1258
+
1259
+ return SequenceClassifierOutput(
1260
+ loss=loss,
1261
+ logits=logits,
1262
+ hidden_states=outputs.hidden_states,
1263
+ attentions=outputs.attentions,
1264
+ )
1265
+
1266
+
1267
+ @add_start_docstrings(
1268
+ """
1269
+ Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1270
+ softmax) e.g. for RocStories/SWAG tasks.
1271
+ """,
1272
+ ROBERTA_START_DOCSTRING,
1273
+ )
1274
+ class RobertaForMultipleChoice(RobertaPreTrainedModel):
1275
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1276
+
1277
+ def __init__(self, config):
1278
+ super().__init__(config)
1279
+
1280
+ self.roberta = RobertaModel(config)
1281
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1282
+ self.classifier = nn.Linear(config.hidden_size, 1)
1283
+
1284
+ # Initialize weights and apply final processing
1285
+ self.post_init()
1286
+
1287
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
1288
+ @add_code_sample_docstrings(
1289
+ checkpoint=_CHECKPOINT_FOR_DOC,
1290
+ output_type=MultipleChoiceModelOutput,
1291
+ config_class=_CONFIG_FOR_DOC,
1292
+ )
1293
+ def forward(
1294
+ self,
1295
+ input_ids: Optional[torch.LongTensor] = None,
1296
+ token_type_ids: Optional[torch.LongTensor] = None,
1297
+ attention_mask: Optional[torch.FloatTensor] = None,
1298
+ labels: Optional[torch.LongTensor] = None,
1299
+ position_ids: Optional[torch.LongTensor] = None,
1300
+ head_mask: Optional[torch.FloatTensor] = None,
1301
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1302
+ output_attentions: Optional[bool] = None,
1303
+ output_hidden_states: Optional[bool] = None,
1304
+ return_dict: Optional[bool] = None,
1305
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
1306
+ r"""
1307
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1308
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1309
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1310
+ `input_ids` above)
1311
+ """
1312
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1313
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1314
+
1315
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1316
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1317
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1318
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1319
+ flat_inputs_embeds = (
1320
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1321
+ if inputs_embeds is not None
1322
+ else None
1323
+ )
1324
+
1325
+ outputs = self.roberta(
1326
+ flat_input_ids,
1327
+ position_ids=flat_position_ids,
1328
+ token_type_ids=flat_token_type_ids,
1329
+ attention_mask=flat_attention_mask,
1330
+ head_mask=head_mask,
1331
+ inputs_embeds=flat_inputs_embeds,
1332
+ output_attentions=output_attentions,
1333
+ output_hidden_states=output_hidden_states,
1334
+ return_dict=return_dict,
1335
+ )
1336
+ pooled_output = outputs[1]
1337
+
1338
+ pooled_output = self.dropout(pooled_output)
1339
+ logits = self.classifier(pooled_output)
1340
+ reshaped_logits = logits.view(-1, num_choices)
1341
+
1342
+ loss = None
1343
+ if labels is not None:
1344
+ # move labels to correct device to enable model parallelism
1345
+ labels = labels.to(reshaped_logits.device)
1346
+ loss_fct = CrossEntropyLoss()
1347
+ loss = loss_fct(reshaped_logits, labels)
1348
+
1349
+ if not return_dict:
1350
+ output = (reshaped_logits,) + outputs[2:]
1351
+ return ((loss,) + output) if loss is not None else output
1352
+
1353
+ return MultipleChoiceModelOutput(
1354
+ loss=loss,
1355
+ logits=reshaped_logits,
1356
+ hidden_states=outputs.hidden_states,
1357
+ attentions=outputs.attentions,
1358
+ )
1359
+
1360
+
1361
+ @add_start_docstrings(
1362
+ """
1363
+ Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1364
+ Named-Entity-Recognition (NER) tasks.
1365
+ """,
1366
+ ROBERTA_START_DOCSTRING,
1367
+ )
1368
+ class RobertaForTokenClassification(RobertaPreTrainedModel):
1369
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1370
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1371
+
1372
+ def __init__(self, config):
1373
+ super().__init__(config)
1374
+ self.num_labels = config.num_labels
1375
+
1376
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1377
+ classifier_dropout = (
1378
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1379
+ )
1380
+ self.dropout = nn.Dropout(classifier_dropout)
1381
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1382
+
1383
+ # Initialize weights and apply final processing
1384
+ self.post_init()
1385
+
1386
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1387
+ @add_code_sample_docstrings(
1388
+ checkpoint="Jean-Baptiste/roberta-large-ner-english",
1389
+ output_type=TokenClassifierOutput,
1390
+ config_class=_CONFIG_FOR_DOC,
1391
+ expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']",
1392
+ expected_loss=0.01,
1393
+ )
1394
+ def forward(
1395
+ self,
1396
+ input_ids: Optional[torch.LongTensor] = None,
1397
+ attention_mask: Optional[torch.FloatTensor] = None,
1398
+ token_type_ids: Optional[torch.LongTensor] = None,
1399
+ position_ids: Optional[torch.LongTensor] = None,
1400
+ head_mask: Optional[torch.FloatTensor] = None,
1401
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1402
+ labels: Optional[torch.LongTensor] = None,
1403
+ output_attentions: Optional[bool] = None,
1404
+ output_hidden_states: Optional[bool] = None,
1405
+ return_dict: Optional[bool] = None,
1406
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1407
+ r"""
1408
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1409
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1410
+ """
1411
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1412
+
1413
+ outputs = self.roberta(
1414
+ input_ids,
1415
+ attention_mask=attention_mask,
1416
+ token_type_ids=token_type_ids,
1417
+ position_ids=position_ids,
1418
+ head_mask=head_mask,
1419
+ inputs_embeds=inputs_embeds,
1420
+ output_attentions=output_attentions,
1421
+ output_hidden_states=output_hidden_states,
1422
+ return_dict=return_dict,
1423
+ )
1424
+
1425
+ sequence_output = outputs[0]
1426
+
1427
+ sequence_output = self.dropout(sequence_output)
1428
+ logits = self.classifier(sequence_output)
1429
+
1430
+ loss = None
1431
+ if labels is not None:
1432
+ # move labels to correct device to enable model parallelism
1433
+ labels = labels.to(logits.device)
1434
+ loss_fct = CrossEntropyLoss()
1435
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1436
+
1437
+ if not return_dict:
1438
+ output = (logits,) + outputs[2:]
1439
+ return ((loss,) + output) if loss is not None else output
1440
+
1441
+ return TokenClassifierOutput(
1442
+ loss=loss,
1443
+ logits=logits,
1444
+ hidden_states=outputs.hidden_states,
1445
+ attentions=outputs.attentions,
1446
+ )
1447
+
1448
+
1449
+ class RobertaClassificationHead(nn.Module):
1450
+ """Head for sentence-level classification tasks."""
1451
+
1452
+ def __init__(self, config):
1453
+ super().__init__()
1454
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1455
+ classifier_dropout = (
1456
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
1457
+ )
1458
+ self.dropout = nn.Dropout(classifier_dropout)
1459
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1460
+
1461
+ def forward(self, features, **kwargs):
1462
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1463
+ x = self.dropout(x)
1464
+ x = self.dense(x)
1465
+ x = torch.tanh(x)
1466
+ x = self.dropout(x)
1467
+ x = self.out_proj(x)
1468
+ return x
1469
+
1470
+
1471
+ @add_start_docstrings(
1472
+ """
1473
+ Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1474
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1475
+ """,
1476
+ ROBERTA_START_DOCSTRING,
1477
+ )
1478
+ class RobertaForQuestionAnswering(RobertaPreTrainedModel):
1479
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1480
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
1481
+
1482
+ def __init__(self, config):
1483
+ super().__init__(config)
1484
+ self.num_labels = config.num_labels
1485
+
1486
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1487
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1488
+
1489
+ # Initialize weights and apply final processing
1490
+ self.post_init()
1491
+
1492
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1493
+ @add_code_sample_docstrings(
1494
+ checkpoint="deepset/roberta-base-squad2",
1495
+ output_type=QuestionAnsweringModelOutput,
1496
+ config_class=_CONFIG_FOR_DOC,
1497
+ expected_output="' puppet'",
1498
+ expected_loss=0.86,
1499
+ )
1500
+ def forward(
1501
+ self,
1502
+ input_ids: Optional[torch.LongTensor] = None,
1503
+ attention_mask: Optional[torch.FloatTensor] = None,
1504
+ token_type_ids: Optional[torch.LongTensor] = None,
1505
+ position_ids: Optional[torch.LongTensor] = None,
1506
+ head_mask: Optional[torch.FloatTensor] = None,
1507
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1508
+ start_positions: Optional[torch.LongTensor] = None,
1509
+ end_positions: Optional[torch.LongTensor] = None,
1510
+ output_attentions: Optional[bool] = None,
1511
+ output_hidden_states: Optional[bool] = None,
1512
+ return_dict: Optional[bool] = None,
1513
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1514
+ r"""
1515
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1516
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1517
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1518
+ are not taken into account for computing the loss.
1519
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1520
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1521
+             Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
1522
+ are not taken into account for computing the loss.
1523
+ """
1524
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1525
+
1526
+ outputs = self.roberta(
1527
+ input_ids,
1528
+ attention_mask=attention_mask,
1529
+ token_type_ids=token_type_ids,
1530
+ position_ids=position_ids,
1531
+ head_mask=head_mask,
1532
+ inputs_embeds=inputs_embeds,
1533
+ output_attentions=output_attentions,
1534
+ output_hidden_states=output_hidden_states,
1535
+ return_dict=return_dict,
1536
+ )
1537
+
1538
+ sequence_output = outputs[0]
1539
+
1540
+ logits = self.qa_outputs(sequence_output)
1541
+ start_logits, end_logits = logits.split(1, dim=-1)
1542
+ start_logits = start_logits.squeeze(-1).contiguous()
1543
+ end_logits = end_logits.squeeze(-1).contiguous()
1544
+
1545
+ total_loss = None
1546
+ if start_positions is not None and end_positions is not None:
1547
+             # If we are on multi-GPU, the split adds an extra dimension; squeeze it away
1548
+ if len(start_positions.size()) > 1:
1549
+ start_positions = start_positions.squeeze(-1)
1550
+ if len(end_positions.size()) > 1:
1551
+ end_positions = end_positions.squeeze(-1)
1552
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1553
+ ignored_index = start_logits.size(1)
1554
+ start_positions = start_positions.clamp(0, ignored_index)
1555
+ end_positions = end_positions.clamp(0, ignored_index)
1556
+
1557
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1558
+ start_loss = loss_fct(start_logits, start_positions)
1559
+ end_loss = loss_fct(end_logits, end_positions)
1560
+ total_loss = (start_loss + end_loss) / 2
1561
+
1562
+ if not return_dict:
1563
+ output = (start_logits, end_logits) + outputs[2:]
1564
+ return ((total_loss,) + output) if total_loss is not None else output
1565
+
1566
+ return QuestionAnsweringModelOutput(
1567
+ loss=total_loss,
1568
+ start_logits=start_logits,
1569
+ end_logits=end_logits,
1570
+ hidden_states=outputs.hidden_states,
1571
+ attentions=outputs.attentions,
1572
+ )
1573
+
1574
+
1575
+ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
1576
+ """
1577
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
1578
+ are ignored. This is modified from fairseq's `utils.make_positions`.
1579
+
1580
+ Args:
1581
+         input_ids (torch.Tensor): input token ids; positions equal to `padding_idx` are treated as padding.
1582
+
1583
+ Returns: torch.Tensor
1584
+ """
1585
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
1586
+ mask = input_ids.ne(padding_idx).int()
1587
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
1588
+ return incremental_indices.long() + padding_idx
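A minimal sketch of what `create_position_ids_from_input_ids` computes, assuming RoBERTa's usual `padding_idx = 1`; the token ids below are arbitrary placeholders:

```python
import torch

padding_idx = 1                                       # assumed RoBERTa-style pad id
input_ids = torch.tensor([[0, 31414, 232, 2, 1, 1]])  # last two positions are padding

mask = input_ids.ne(padding_idx).int()
position_ids = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
print(position_ids)  # tensor([[2, 3, 4, 5, 1, 1]]) -- padding keeps padding_idx
```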
soft_prompt/model/sequence_causallm.py ADDED
@@ -0,0 +1,1249 @@
1
+ import torch
2
+ from torch._C import NoopLogger
3
+ import torch.nn
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+ from typing import List, Optional, Tuple, Union
7
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
8
+
9
+ from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel, BertOnlyMLMHead
10
+ from transformers.models.opt.modeling_opt import OPTModel, OPTPreTrainedModel
11
+ from transformers.models.roberta.modeling_roberta import RobertaLMHead, RobertaModel, RobertaPreTrainedModel
12
+ from transformers.models.llama.modeling_llama import LlamaPreTrainedModel, LlamaModel, CausalLMOutputWithPast
13
+ from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel
14
+ from transformers.modeling_outputs import SequenceClassifierOutput, SequenceClassifierOutputWithPast, BaseModelOutput, Seq2SeqLMOutput
15
+ from .prefix_encoder import PrefixEncoder
16
+ from . import utils
17
+ import hashlib
18
+
19
+
20
+ def hash_nn(model):
21
+ md5 = hashlib.md5()  # parameter fingerprint only, not a security-sensitive use
22
+ for arg in model.parameters():
23
+ x = arg.data
24
+ if hasattr(x, "cpu"):
25
+ md5.update(x.cpu().numpy().data.tobytes())
26
+ elif hasattr(x, "numpy"):
27
+ md5.update(x.numpy().data.tobytes())
28
+ elif hasattr(x, "data"):
29
+ md5.update(x.data.tobytes())
30
+ else:
31
+ try:
32
+ md5.update(x.encode("utf-8"))
33
+ except Exception:
34
+ md5.update(str(x).encode("utf-8"))
35
+ return md5.hexdigest()
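A hypothetical usage sketch of `hash_nn`: the digest changes whenever any parameter changes, so it can serve as a quick fingerprint of a checkpoint. It assumes the `hash_nn` defined above is in scope, and the tiny `torch.nn.Linear` is only for illustration.

```python
import torch

# assumes hash_nn from this module is importable / in scope
model = torch.nn.Linear(4, 2)
digest_before = hash_nn(model)

with torch.no_grad():
    model.weight.add_(1.0)        # any parameter update changes the fingerprint

digest_after = hash_nn(model)
assert digest_before != digest_after
```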
36
+
37
+
38
+ class OPTPrefixForMaskedLM(OPTPreTrainedModel):
39
+ _tied_weights_keys = ["lm_head.weight"]
40
+ def __init__(self, config):
41
+ super().__init__(config)
42
+ self.model = OPTModel(config)
43
+ self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
44
+ self.dropout = torch.nn.Dropout(0.1)
45
+ for param in self.model.parameters():
46
+ param.requires_grad = False
47
+
48
+ self.pre_seq_len = config.pre_seq_len
49
+ self.n_layer = config.num_hidden_layers
50
+ self.n_head = config.num_attention_heads
51
+ self.n_embd = config.hidden_size // config.num_attention_heads
52
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
53
+ self.prefix_encoder = PrefixEncoder(config)
54
+
55
+ base_param = 0
56
+ for name, param in self.model.named_parameters():
57
+ base_param += param.numel()
58
+ all_param = 0
59
+ for name, param in self.named_parameters():
60
+ all_param += param.numel()
61
+ total_param = all_param - base_param
62
+ print('-> OPT_param:{:0.2f}M P-tuning-V2 param is {}'.format(base_param / 1000000, total_param))
63
+
64
+ self.embedding = self.get_input_embeddings()
65
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
66
+ self.clean_labels = torch.tensor(config.clean_labels).long()
67
+
68
+ def get_input_embeddings(self):
69
+ return self.model.decoder.embed_tokens
70
+
71
+ def set_input_embeddings(self, value):
72
+ self.model.decoder.embed_tokens = value
73
+
74
+ def get_output_embeddings(self):
75
+ return self.lm_head
76
+
77
+ def set_output_embeddings(self, new_embeddings):
78
+ self.lm_head = new_embeddings
79
+
80
+ def set_decoder(self, decoder):
81
+ self.model.decoder = decoder
82
+
83
+ def get_decoder(self):
84
+ return self.model.decoder
85
+
86
+ def get_prompt(self, batch_size):
87
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.model.device)
88
+ past_key_values = self.prefix_encoder(prefix_tokens)
89
+ # bsz, seqlen, _ = past_key_values.shape
90
+ past_key_values = past_key_values.view(
91
+ batch_size,
92
+ self.pre_seq_len,
93
+ self.n_layer * 2,
94
+ self.n_head,
95
+ self.n_embd
96
+ )
97
+ past_key_values = self.dropout(past_key_values)
98
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
99
+ return past_key_values
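A shape-only sketch (illustrative sizes, not the real config) of the reshaping done in `get_prompt`: the prefix encoder's flat output becomes one stacked `(key, value)` pair of shape `(batch, n_head, pre_seq_len, n_embd)` per transformer layer.

```python
import torch

batch, pre_seq_len, n_layer, n_head, n_embd = 2, 8, 12, 12, 64

flat = torch.randn(batch, pre_seq_len, n_layer * 2 * n_head * n_embd)   # prefix_encoder output
pkv = flat.view(batch, pre_seq_len, n_layer * 2, n_head, n_embd)
pkv = pkv.permute([2, 0, 3, 1, 4]).split(2)                              # one entry per layer

assert len(pkv) == n_layer
assert pkv[0].shape == (2, batch, n_head, pre_seq_len, n_embd)           # stacked (key, value)
```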
100
+
101
+ def use_grad(self, transformer, use_grad):
102
+ if use_grad:
103
+ for param in transformer.parameters():
104
+ param.requires_grad = True
105
+ transformer.train()
106
+ else:
107
+ for param in transformer.parameters():
108
+ param.requires_grad = False
109
+ transformer.eval()
110
+ for param in self.lm_head.parameters():
111
+ param.requires_grad = True
112
+ for param in self.prefix_encoder.parameters():
113
+ param.requires_grad = True
114
+
115
+ def forward(
116
+ self,
117
+ input_ids: torch.LongTensor = None,
118
+ attention_mask: Optional[torch.Tensor] = None,
119
+ head_mask: Optional[torch.Tensor] = None,
120
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
121
+ inputs_embeds: Optional[torch.FloatTensor] = None,
122
+ labels: Optional[torch.LongTensor] = None,
123
+ token_labels: Optional[torch.LongTensor] = None,
124
+ use_cache: Optional[bool] = None,
125
+ output_attentions: Optional[bool] = None,
126
+ output_hidden_states: Optional[bool] = None,
127
+ return_dict: Optional[bool] = None,
128
+ use_base_grad=False,
129
+ ):
130
+ r"""
131
+ Args:
132
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
133
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
134
+ provide it.
135
+
136
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
137
+ [`PreTrainedTokenizer.__call__`] for details.
138
+
139
+ [What are input IDs?](../glossary#input-ids)
140
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
141
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
142
+
143
+ - 1 for tokens that are **not masked**,
144
+ - 0 for tokens that are **masked**.
145
+
146
+ [What are attention masks?](../glossary#attention-mask)
147
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
148
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
149
+
150
+ - 1 indicates the head is **not masked**,
151
+ - 0 indicates the head is **masked**.
152
+
153
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
154
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
155
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
156
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
157
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
158
+
159
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
160
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
161
+
162
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
163
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
164
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
165
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
166
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
167
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
168
+ than the model's internal embedding lookup matrix.
169
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
170
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
171
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
172
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
173
+ use_cache (`bool`, *optional*):
174
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
175
+ (see `past_key_values`).
176
+ output_attentions (`bool`, *optional*):
177
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
178
+ returned tensors for more detail.
179
+ output_hidden_states (`bool`, *optional*):
180
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
181
+ for more detail.
182
+ return_dict (`bool`, *optional*):
183
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
184
+
185
+ Returns:
186
+
187
+ Example:
188
+
189
+ ```python
190
+ >>> from transformers import AutoTokenizer, OPTForCausalLM
191
+
192
+ >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
193
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
194
+
195
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
196
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
197
+
198
+ >>> # Generate
199
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
200
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
201
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
202
+ ```"""
203
+ utils.use_grad(self.model, use_base_grad)
204
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
205
+ output_hidden_states = (
206
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
207
+ )
208
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
209
+ batch_size = input_ids.shape[0]
210
+ past_key_values = self.get_prompt(batch_size=batch_size)
211
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.model.device)
212
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
213
+
214
+ outputs = self.model.decoder(
215
+ input_ids=input_ids,
216
+ attention_mask=attention_mask,
217
+ inputs_embeds=inputs_embeds,
218
+ use_cache=use_cache,
219
+ output_attentions=output_attentions,
220
+ output_hidden_states=output_hidden_states,
221
+ return_dict=return_dict,
222
+ past_key_values=past_key_values,
223
+ )
224
+ sequence_output = outputs[0]
225
+ sequence_output = self.dropout(sequence_output)
226
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(self.model.device)
227
+ cls_token = sequence_output[torch.arange(batch_size, device=self.model.device), sequence_lengths].contiguous()
228
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous()
229
+
230
+ # compute loss
231
+ masked_lm_loss = None
232
+ if token_labels is not None:
233
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
234
+ else:
235
+ if labels is not None:
236
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
237
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
238
+
239
+ # convert to binary classifier
240
+ probs = []
241
+ for y in self.clean_labels:
242
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
243
+ logits = torch.stack(probs).T
244
+
245
+ return SequenceClassifierOutput(
246
+ loss=masked_lm_loss,
247
+ logits=logits,
248
+ hidden_states=outputs.hidden_states,
249
+ attentions=attentions
250
+ )
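The `# convert to binary classifier` step above can be read in isolation: vocabulary logits are reduced to one score per class by taking the maximum over that class's verbalizer token ids. All ids and sizes below are made up.

```python
import torch

vocab_logits = torch.randn(4, 50272)                 # (batch, vocab_size)
clean_labels = [torch.tensor([2430, 33407]),         # token ids verbalizing class 0
                torch.tensor([3893, 24905])]         # token ids verbalizing class 1

probs = [vocab_logits[:, y].max(dim=1)[0] for y in clean_labels]
class_logits = torch.stack(probs).T                  # (batch, num_classes)
assert class_logits.shape == (4, 2)
```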
251
+
252
+ def prepare_inputs_for_generation(
253
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
254
+ ):
255
+ if past_key_values:
256
+ input_ids = input_ids[:, -1:]
257
+
258
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
259
+ if inputs_embeds is not None and past_key_values is None:
260
+ model_inputs = {"inputs_embeds": inputs_embeds}
261
+ else:
262
+ model_inputs = {"input_ids": input_ids}
263
+
264
+ model_inputs.update(
265
+ {
266
+ "past_key_values": past_key_values,
267
+ "use_cache": kwargs.get("use_cache"),
268
+ "attention_mask": attention_mask,
269
+ }
270
+ )
271
+ return model_inputs
272
+
273
+ @staticmethod
274
+ def _reorder_cache(past_key_values, beam_idx):
275
+ reordered_past = ()
276
+ for layer_past in past_key_values:
277
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
278
+ return reordered_past
279
+
280
+
281
+ class OPTPromptForMaskedLM(OPTPreTrainedModel):
282
+ _tied_weights_keys = ["lm_head.weight"]
283
+
284
+ def __init__(self, config):
285
+ super().__init__(config)
286
+ self.num_labels = config.num_labels
287
+ self.model = OPTModel(config)
288
+ self.score = torch.nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
289
+ self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
290
+ self.dropout = torch.nn.Dropout(0.1)
291
+ for param in self.model.parameters():
292
+ param.requires_grad = False
293
+
294
+ self.pre_seq_len = config.pre_seq_len
295
+ self.n_layer = config.num_hidden_layers
296
+ self.n_head = config.num_attention_heads
297
+ self.n_embd = config.hidden_size // config.num_attention_heads
298
+
299
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
300
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
301
+
302
+ model_param = 0
303
+ for name, param in self.model.named_parameters():
304
+ model_param += param.numel()
305
+ all_param = 0
306
+ for name, param in self.named_parameters():
307
+ all_param += param.numel()
308
+ total_param = all_param - model_param
309
+ print('-> OPT_param:{:0.2f}M P-tuning-V2 param is {}'.format(model_param / 1000000, total_param))
310
+
311
+ self.embedding = self.model.decoder.embed_tokens
312
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
313
+ self.clean_labels = torch.tensor(config.clean_labels).long()
314
+
315
+ def get_input_embeddings(self):
316
+ return self.model.decoder.embed_tokens
317
+
318
+ def set_input_embeddings(self, value):
319
+ self.model.decoder.embed_tokens = value
320
+
321
+ def get_output_embeddings(self):
322
+ return self.lm_head
323
+
324
+ def set_output_embeddings(self, new_embeddings):
325
+ self.lm_head = new_embeddings
326
+
327
+ def set_decoder(self, decoder):
328
+ self.model.decoder = decoder
329
+
330
+ def get_decoder(self):
331
+ return self.model.decoder
332
+
333
+ def get_prompt(self, batch_size):
334
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.model.device)
335
+ prompts = self.prefix_encoder(prefix_tokens)
336
+ return prompts
337
+
338
+ def use_grad(self, transformer, use_grad):
339
+ if use_grad:
340
+ for param in transformer.parameters():
341
+ param.requires_grad = True
342
+ transformer.train()
343
+ else:
344
+ for param in transformer.parameters():
345
+ param.requires_grad = False
346
+ transformer.eval()
347
+ for param in self.lm_head.parameters():
348
+ param.requires_grad = True
349
+ for param in self.prefix_encoder.parameters():
350
+ param.requires_grad = True
351
+
352
+ def forward(
353
+ self,
354
+ input_ids: torch.LongTensor = None,
355
+ attention_mask: Optional[torch.Tensor] = None,
356
+ head_mask: Optional[torch.Tensor] = None,
357
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
358
+ inputs_embeds: Optional[torch.FloatTensor] = None,
359
+ labels: Optional[torch.LongTensor] = None,
360
+ token_labels: Optional[torch.LongTensor] = None,
361
+ use_cache: Optional[bool] = None,
362
+ output_attentions: Optional[bool] = None,
363
+ output_hidden_states: Optional[bool] = None,
364
+ return_dict: Optional[bool] = None,
365
+ use_base_grad=False,
366
+ ):
367
+ r"""
368
+ Args:
369
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
370
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
371
+ provide it.
372
+
373
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
374
+ [`PreTrainedTokenizer.__call__`] for details.
375
+
376
+ [What are input IDs?](../glossary#input-ids)
377
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
378
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
379
+
380
+ - 1 for tokens that are **not masked**,
381
+ - 0 for tokens that are **masked**.
382
+
383
+ [What are attention masks?](../glossary#attention-mask)
384
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
385
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
386
+
387
+ - 1 indicates the head is **not masked**,
388
+ - 0 indicates the head is **masked**.
389
+
390
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
391
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
392
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
393
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
394
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
395
+
396
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
397
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
398
+
399
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
400
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
401
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
402
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
403
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
404
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
405
+ than the model's internal embedding lookup matrix.
406
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
407
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
408
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
409
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
410
+ use_cache (`bool`, *optional*):
411
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
412
+ (see `past_key_values`).
413
+ output_attentions (`bool`, *optional*):
414
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
415
+ returned tensors for more detail.
416
+ output_hidden_states (`bool`, *optional*):
417
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
418
+ for more detail.
419
+ return_dict (`bool`, *optional*):
420
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
421
+
422
+ Returns:
423
+
424
+ Example:
425
+
426
+ ```python
427
+ >>> from transformers import AutoTokenizer, OPTForCausalLM
428
+
429
+ >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
430
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
431
+
432
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
433
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
434
+
435
+ >>> # Generate
436
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
437
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
438
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
439
+ ```"""
440
+ utils.use_grad(self.model, use_base_grad)
441
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
442
+ output_hidden_states = (
443
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
444
+ )
445
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
446
+
447
+ batch_size = input_ids.shape[0]
448
+ raw_embedding = self.model.decoder.embed_tokens(input_ids)
449
+ prompts = self.get_prompt(batch_size=batch_size)
450
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
451
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.model.device)
452
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
453
+
454
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
455
+ outputs = self.model.decoder(
456
+ attention_mask=attention_mask,
457
+ inputs_embeds=inputs_embeds,
458
+ use_cache=use_cache,
459
+ output_attentions=output_attentions,
460
+ output_hidden_states=output_hidden_states,
461
+ return_dict=return_dict,
462
+ )
463
+ sequence_output = outputs[0]
464
+ sequence_output = sequence_output[:, self.pre_seq_len:, :]
465
+ sequence_output = self.dropout(sequence_output)
466
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(self.model.device)
467
+ cls_token = sequence_output[torch.arange(batch_size, device=self.model.device), sequence_lengths].contiguous()
468
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous()
469
+
470
+ # compute loss
471
+ loss = None
472
+ if token_labels is not None:
473
+ loss = utils.get_loss(attentions, token_labels).sum()
474
+ else:
475
+ if labels is not None:
476
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
477
+ loss = utils.get_loss(attentions, token_labels).sum()
478
+
479
+ # convert to binary classifier
480
+ probs = []
481
+ for idx, y in enumerate(self.clean_labels):
482
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
483
+ logits = torch.stack(probs).T
484
+ #loss = torch.nn.functional.nll_loss(logits, labels)
485
+
486
+ return SequenceClassifierOutput(
487
+ loss=loss,
488
+ logits=logits,
489
+ hidden_states=outputs.hidden_states,
490
+ attentions=attentions
491
+ )
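A sketch, with made-up sizes, of the soft-prompt path used by the `*Prompt*` classes: learned prompt embeddings are concatenated in front of the token embeddings, the attention mask is widened to match, and the prompt positions are sliced off the decoder output before pooling.

```python
import torch

batch, seq_len, pre_seq_len, hidden = 2, 16, 8, 512

raw_embedding = torch.randn(batch, seq_len, hidden)           # embed_tokens(input_ids)
prompts = torch.randn(batch, pre_seq_len, hidden)             # prefix_encoder(prefix_tokens)
inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)    # (batch, pre_seq_len + seq_len, hidden)

attention_mask = torch.ones(batch, seq_len)
attention_mask = torch.cat((torch.ones(batch, pre_seq_len), attention_mask), dim=1)

sequence_output = torch.randn(batch, pre_seq_len + seq_len, hidden)   # stand-in decoder output
sequence_output = sequence_output[:, pre_seq_len:, :]                 # drop the prompt positions
assert sequence_output.shape == (batch, seq_len, hidden)
```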
492
+
493
+ def prepare_inputs_for_generation(
494
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
495
+ ):
496
+ if past_key_values:
497
+ input_ids = input_ids[:, -1:]
498
+
499
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
500
+ if inputs_embeds is not None and past_key_values is None:
501
+ model_inputs = {"inputs_embeds": inputs_embeds}
502
+ else:
503
+ model_inputs = {"input_ids": input_ids}
504
+
505
+ model_inputs.update(
506
+ {
507
+ "past_key_values": past_key_values,
508
+ "use_cache": kwargs.get("use_cache"),
509
+ "attention_mask": attention_mask,
510
+ }
511
+ )
512
+ return model_inputs
513
+
514
+ @staticmethod
515
+ def _reorder_cache(past_key_values, beam_idx):
516
+ reordered_past = ()
517
+ for layer_past in past_key_values:
518
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
519
+ return reordered_past
520
+
521
+
522
+ class LlamaPrefixForMaskedLM(LlamaPreTrainedModel):
523
+ _tied_weights_keys = ["lm_head.weight"]
524
+
525
+ def __init__(self, config):
526
+ super().__init__(config)
527
+ self.model = LlamaModel(config)
528
+ self.vocab_size = config.vocab_size
529
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
530
+ self.dropout = torch.nn.Dropout(0.1)
531
+ for param in self.model.parameters():
532
+ param.requires_grad = False
533
+
534
+ self.pre_seq_len = config.pre_seq_len
535
+ self.n_layer = config.num_hidden_layers
536
+ self.n_head = config.num_attention_heads
537
+ self.n_embd = config.hidden_size // config.num_attention_heads
538
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
539
+ self.prefix_encoder = PrefixEncoder(config)
540
+
541
+ base_param = 0
542
+ for name, param in self.model.named_parameters():
543
+ base_param += param.numel()
544
+ all_param = 0
545
+ for name, param in self.named_parameters():
546
+ all_param += param.numel()
547
+ total_param = all_param - base_param
548
+ print('-> Llama_param:{:0.2f}M P-tuning-V2 param:{:0.2f}M'.format(base_param / 1000000, total_param / 1000000))
549
+
550
+ self.embedding = self.model.embed_tokens
551
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
552
+ self.clean_labels = torch.tensor(config.clean_labels).long()
553
+
554
+ def get_prompt(self, batch_size):
555
+ device = next(self.prefix_encoder.parameters()).device
556
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
557
+ past_key_values = self.prefix_encoder(prefix_tokens)
558
+ # bsz, seqlen, _ = past_key_values.shape
559
+ past_key_values = past_key_values.view(
560
+ batch_size,
561
+ self.pre_seq_len,
562
+ self.n_layer * 2,
563
+ self.n_head,
564
+ self.n_embd
565
+ )
566
+ past_key_values = self.dropout(past_key_values)
567
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
568
+ return past_key_values
569
+
570
+ def get_input_embeddings(self):
571
+ return self.model.embed_tokens
572
+
573
+ def set_input_embeddings(self, value):
574
+ self.model.embed_tokens = value
575
+
576
+ def get_output_embeddings(self):
577
+ return self.lm_head
578
+
579
+ def set_output_embeddings(self, new_embeddings):
580
+ self.lm_head = new_embeddings
581
+
582
+ def set_decoder(self, decoder):
583
+ self.model = decoder
584
+
585
+ def get_decoder(self):
586
+ return self.model
587
+
588
+ def use_grad(self, base_model, use_grad):
589
+ if use_grad:
590
+ for param in base_model.parameters():
591
+ param.requires_grad = True
592
+ base_model.train()
593
+ else:
594
+ for param in base_model.parameters():
595
+ param.requires_grad = False
596
+ base_model.eval()
597
+ for param in self.prefix_encoder.parameters():
598
+ param.requires_grad = True
599
+ for param in self.lm_head.parameters():
600
+ param.requires_grad = True
601
+
602
+ def forward(
603
+ self,
604
+ input_ids=None,
605
+ attention_mask=None,
606
+ token_type_ids=None,
607
+ position_ids=None,
608
+ inputs_embeds=None,
609
+ labels=None,
610
+ token_labels=None,
611
+ output_attentions=None,
612
+ output_hidden_states=None,
613
+ return_dict=None,
614
+ use_base_grad=False,
615
+ ):
616
+ utils.use_grad(self.model, use_base_grad)
617
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
618
+ batch_size = input_ids.shape[0]
619
+ past_key_values = self.get_prompt(batch_size=batch_size)
620
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(attention_mask.device)
621
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
622
+
623
+ outputs = self.model(
624
+ input_ids=input_ids,
625
+ attention_mask=attention_mask,
626
+ position_ids=position_ids,
627
+ inputs_embeds=inputs_embeds,
628
+ output_attentions=output_attentions,
629
+ output_hidden_states=output_hidden_states,
630
+ return_dict=return_dict,
631
+ past_key_values=past_key_values,
632
+ )
633
+ sequence_output = outputs[0]
634
+ sequence_output = self.dropout(sequence_output)
635
+ #sequence_output = torch.clamp(sequence_output, min=-1, max=1)
636
+ #cls_token = sequence_output[:, :1]
637
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(sequence_output.device)
638
+ cls_token = sequence_output[torch.arange(batch_size, device=sequence_output.device), sequence_lengths].contiguous()
639
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous()
640
+
641
+ # compute loss
642
+ masked_lm_loss = None
643
+ if token_labels is not None:
644
+ masked_lm_loss = utils.get_loss(attentions, token_labels.to(attentions.device)).sum()
645
+ else:
646
+ if labels is not None:
647
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
648
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
649
+
650
+ # convert to binary classifier
651
+ probs = []
652
+ for y in self.clean_labels:
653
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
654
+ logits = torch.stack(probs).T
655
+
656
+ return SequenceClassifierOutput(
657
+ loss=masked_lm_loss,
658
+ logits=logits,
659
+ hidden_states=outputs.hidden_states,
660
+ attentions=attentions
661
+ )
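A worked toy example of the last-token pooling used in these causal-LM heads: because the models read left to right, the hidden state of the last non-padding token stands in for a `[CLS]` vector. `pad_token_id = 0`, the ids, and the hidden size are placeholders.

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 8, 9, 0, 0],
                          [7, 4, 0, 0, 0]])
hidden = torch.randn(2, 5, 3)                                        # fake decoder output

sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1     # tensor([2, 1])
cls_token = hidden[torch.arange(2), sequence_lengths]                # last real token per row
assert cls_token.shape == (2, 3)
```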
662
+
663
+
664
+ class LlamaPromptForMaskedLM(LlamaPreTrainedModel):
665
+ _tied_weights_keys = ["lm_head.weight"]
666
+ def __init__(self, config):
667
+ super().__init__(config)
668
+ self.model = LlamaModel(config)
669
+ self.vocab_size = config.vocab_size
670
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
671
+ self.dropout = torch.nn.Dropout(0.1)
672
+ for param in self.model.parameters():
673
+ param.requires_grad = False
674
+
675
+ self.pre_seq_len = config.pre_seq_len
676
+ self.n_layer = config.num_hidden_layers
677
+ self.n_head = config.num_attention_heads
678
+ self.n_embd = config.hidden_size // config.num_attention_heads
679
+
680
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
681
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
682
+
683
+ model_param = 0
684
+ for name, param in self.model.named_parameters():
685
+ model_param += param.numel()
686
+ all_param = 0
687
+ for name, param in self.named_parameters():
688
+ all_param += param.numel()
689
+ total_param = all_param - model_param
690
+ print('-> Llama_param:{:0.2f}M P-tuning-V2 param is {:0.2f}M'.format(model_param / 1000000, total_param / 1000000))
691
+
692
+ self.pad_token_id = 2
693
+ self.embedding = self.model.embed_tokens
694
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
695
+ self.clean_labels = torch.tensor(config.clean_labels).long()
696
+
697
+ def get_prompt(self, batch_size):
698
+ device = next(self.prefix_encoder.parameters()).device
699
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
700
+ prompts = self.prefix_encoder(prefix_tokens)
701
+ return prompts
702
+
703
+ def get_input_embeddings(self):
704
+ return self.model.embed_tokens
705
+
706
+ def set_input_embeddings(self, value):
707
+ self.model.embed_tokens = value
708
+
709
+ def get_output_embeddings(self):
710
+ return self.lm_head
711
+
712
+ def set_output_embeddings(self, new_embeddings):
713
+ self.lm_head = new_embeddings
714
+
715
+ def set_decoder(self, decoder):
716
+ self.model = decoder
717
+
718
+ def get_decoder(self):
719
+ return self.model
720
+
721
+ def use_grad(self, base_model, use_grad):
722
+ if use_grad:
723
+ for param in base_model.parameters():
724
+ param.requires_grad = True
725
+ for param in self.lm_head.parameters():
726
+ param.requires_grad = True
727
+ base_model.train()
728
+ else:
729
+ for param in base_model.parameters():
730
+ param.requires_grad = False
731
+ for param in self.lm_head.parameters():
732
+ param.requires_grad = False
733
+ base_model.eval()
734
+ for param in self.prefix_encoder.parameters():
735
+ param.requires_grad = True
736
+
737
+
738
+ def forward(
739
+ self,
740
+ input_ids: torch.LongTensor = None,
741
+ attention_mask: Optional[torch.Tensor] = None,
742
+ position_ids: Optional[torch.LongTensor] = None,
743
+ token_type_ids: Optional[torch.LongTensor] = None,
744
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
745
+ inputs_embeds: Optional[torch.FloatTensor] = None,
746
+ head_mask: Optional[torch.FloatTensor] = None,
747
+ labels: Optional[torch.LongTensor] = None,
748
+ token_labels: Optional[torch.LongTensor] = None,
749
+ use_cache: Optional[bool] = None,
750
+ output_attentions: Optional[bool] = None,
751
+ output_hidden_states: Optional[bool] = None,
752
+ return_dict: Optional[bool] = None,
753
+ use_base_grad: Optional[bool] = False,
754
+ ):
755
+ self.use_grad(self.model, use_base_grad)
756
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
757
+
758
+ batch_size = input_ids.shape[0]
759
+ raw_embedding = self.model.embed_tokens(input_ids)
760
+ prompts = self.get_prompt(batch_size=batch_size)
761
+ inputs_embeds = torch.cat((prompts, raw_embedding.to(prompts.device)), dim=1)
762
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(attention_mask.device)
763
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
764
+
765
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
766
+ outputs = self.model(
767
+ attention_mask=attention_mask,
768
+ past_key_values=past_key_values,
769
+ inputs_embeds=inputs_embeds,
770
+ use_cache=use_cache,
771
+ output_attentions=output_attentions,
772
+ output_hidden_states=output_hidden_states,
773
+ return_dict=return_dict,
774
+ )
775
+ sequence_output = outputs[0]
776
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
777
+ #cls_token = sequence_output[:, 0]
778
+ #cls_token = self.dropout(cls_token)
779
+ sequence_lengths = (torch.ne(input_ids, self.pad_token_id).sum(-1) - 1).to(sequence_output.device)
780
+ cls_token = sequence_output[torch.arange(batch_size, device=sequence_output.device), sequence_lengths].contiguous()
781
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous().float()
782
+
783
+ # compute loss
784
+ masked_lm_loss = None
785
+ if token_labels is not None:
786
+ masked_lm_loss = utils.get_loss(attentions, token_labels.to(attentions.device)).sum()
787
+ else:
788
+ if labels is not None:
789
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
790
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
791
+
792
+ # convert to binary classifier
793
+ probs = []
794
+ for y in self.clean_labels:
795
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
796
+ logits = torch.stack(probs).T
797
+
798
+ return SequenceClassifierOutput(
799
+ loss=masked_lm_loss,
800
+ logits=logits,
801
+ hidden_states=outputs.hidden_states,
802
+ attentions=attentions
803
+ )
804
+
805
+
806
+ class BertPrefixForMaskedLM(BertPreTrainedModel):
807
+ def __init__(self, config):
808
+ super().__init__(config)
809
+ self.num_labels = config.num_labels
810
+ self.config = config
811
+ self.bert = BertModel(config, add_pooling_layer=False)
812
+ self.cls = BertOnlyMLMHead(config)
813
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
814
+ for param in self.bert.parameters():
815
+ param.requires_grad = False
816
+
817
+ self.pre_seq_len = config.pre_seq_len
818
+ self.n_layer = config.num_hidden_layers
819
+ self.n_head = config.num_attention_heads
820
+ self.n_embd = config.hidden_size // config.num_attention_heads
821
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
822
+ self.prefix_encoder = PrefixEncoder(config)
823
+
824
+ base_param = 0
825
+ for name, param in self.bert.named_parameters():
826
+ base_param += param.numel()
827
+ all_param = 0
828
+ for name, param in self.named_parameters():
829
+ all_param += param.numel()
830
+ total_param = all_param - base_param
831
+ print('-> bert_param:{:0.2f}M P-tuning-V2 param is {}'.format(base_param / 1000000, total_param))
832
+
833
+ # bert.embeddings.word_embeddings
834
+ self.embedding = utils.get_embeddings(self, config)
835
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
836
+ self.clean_labels = torch.tensor(config.clean_labels).long()
837
+
838
+ def get_prompt(self, batch_size):
839
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
840
+ past_key_values = self.prefix_encoder(prefix_tokens)
841
+ # bsz, seqlen, _ = past_key_values.shape
842
+ past_key_values = past_key_values.view(
843
+ batch_size,
844
+ self.pre_seq_len,
845
+ self.n_layer * 2,
846
+ self.n_head,
847
+ self.n_embd
848
+ )
849
+ past_key_values = self.dropout(past_key_values)
850
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
851
+ return past_key_values
852
+
853
+ def forward(
854
+ self,
855
+ input_ids=None,
856
+ attention_mask=None,
857
+ token_type_ids=None,
858
+ position_ids=None,
859
+ head_mask=None,
860
+ inputs_embeds=None,
861
+ labels=None,
862
+ token_labels=None,
863
+ output_attentions=None,
864
+ output_hidden_states=None,
865
+ return_dict=None,
866
+ use_base_grad=False,
867
+ ):
868
+ utils.use_grad(self.bert, use_base_grad)
869
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
870
+ batch_size = input_ids.shape[0]
871
+ past_key_values = self.get_prompt(batch_size=batch_size)
872
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
873
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
874
+
875
+ outputs = self.bert(
876
+ input_ids,
877
+ attention_mask=attention_mask,
878
+ token_type_ids=token_type_ids,
879
+ position_ids=position_ids,
880
+ head_mask=head_mask,
881
+ inputs_embeds=inputs_embeds,
882
+ output_attentions=output_attentions,
883
+ output_hidden_states=output_hidden_states,
884
+ return_dict=return_dict,
885
+ past_key_values=past_key_values,
886
+ )
887
+ sequence_output = outputs[0]
888
+ cls_token = sequence_output[:, 0]
889
+ cls_token = self.dropout(cls_token)
890
+ attentions = self.cls(cls_token).view(-1, self.config.vocab_size)
891
+
892
+
893
+ # compute loss
894
+ masked_lm_loss = None
895
+ if token_labels is not None:
896
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
897
+ else:
898
+ if labels is not None:
899
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
900
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
901
+
902
+ # convert to binary classifier
903
+ probs = []
904
+ for y in self.clean_labels:
905
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
906
+ logits = torch.stack(probs).T
907
+
908
+ if not return_dict:
909
+ output = (logits,) + outputs[2:]
910
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
911
+
912
+ return SequenceClassifierOutput(
913
+ loss=masked_lm_loss,
914
+ logits=logits,
915
+ hidden_states=outputs.hidden_states,
916
+ attentions=attentions
917
+ )
918
+
919
+
920
+ class BertPromptForMaskedLM(BertPreTrainedModel):
921
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
923
+ def __init__(self, config):
923
+ super().__init__(config)
924
+ self.num_labels = config.num_labels
925
+ self.config = config
926
+
927
+ self.bert = BertModel(config, add_pooling_layer=False)
928
+ self.cls = BertOnlyMLMHead(config)
929
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
930
+ for param in self.bert.parameters():
931
+ param.requires_grad = False
932
+
933
+ self.pre_seq_len = config.pre_seq_len
934
+ self.n_layer = config.num_hidden_layers
935
+ self.n_head = config.num_attention_heads
936
+ self.n_embd = config.hidden_size // config.num_attention_heads
937
+
938
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
939
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
940
+
941
+ bert_param = 0
942
+ for name, param in self.bert.named_parameters():
943
+ bert_param += param.numel()
944
+ all_param = 0
945
+ for name, param in self.named_parameters():
946
+ all_param += param.numel()
947
+ total_param = all_param - bert_param
948
+ print('-> bert_param:{:0.2f}M P-tuning-V2 param is {}'.format(bert_param / 1000000, total_param))
949
+
950
+ # bert.embeddings.word_embeddings
951
+ self.embedding = utils.get_embeddings(self, config)
952
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
953
+ self.clean_labels = torch.tensor(config.clean_labels).long()
954
+
955
+ def get_prompt(self, batch_size):
956
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
957
+ prompts = self.prefix_encoder(prefix_tokens)
958
+ return prompts
959
+
960
+ def forward(
961
+ self,
962
+ input_ids=None,
963
+ attention_mask=None,
964
+ token_type_ids=None,
965
+ position_ids=None,
966
+ head_mask=None,
967
+ inputs_embeds=None,
968
+ labels=None,
969
+ token_labels=None,
970
+ output_attentions=None,
971
+ output_hidden_states=None,
972
+ return_dict=None,
973
+ use_base_grad=False,
974
+ ):
975
+ r"""
976
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
977
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
978
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
979
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
980
+ """
981
+ utils.use_grad(self.bert, use_base_grad)
982
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
983
+
984
+ batch_size = input_ids.shape[0]
985
+ raw_embedding = self.bert.embeddings(
986
+ input_ids=input_ids,
987
+ position_ids=position_ids,
988
+ token_type_ids=token_type_ids,
989
+ )
990
+ prompts = self.get_prompt(batch_size=batch_size)
991
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
992
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
993
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
994
+
995
+ outputs = self.bert(
996
+ # input_ids,
997
+ attention_mask=attention_mask,
998
+ # token_type_ids=token_type_ids,
999
+ # position_ids=position_ids,
1000
+ head_mask=head_mask,
1001
+ inputs_embeds=inputs_embeds,
1002
+ output_attentions=output_attentions,
1003
+ output_hidden_states=output_hidden_states,
1004
+ return_dict=return_dict,
1005
+ # past_key_values=past_key_values,
1006
+ )
1007
+ sequence_output = outputs[0]
1008
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
1009
+ cls_token = sequence_output[:, 0]
1010
+ cls_token = self.dropout(cls_token)
1011
+ attentions = self.cls(cls_token).view(-1, self.config.vocab_size)
1012
+
1013
+ # compute loss
1014
+ masked_lm_loss = None
1015
+ if token_labels is not None:
1016
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
1017
+ else:
1018
+ if labels is not None:
1019
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
1020
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
1021
+
1022
+ # convert to binary classifier
1023
+ probs = []
1024
+ for y in self.clean_labels:
1025
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
1026
+ logits = torch.stack(probs).T
1027
+
1028
+ if not return_dict:
1029
+ output = (logits,) + outputs[2:]
1030
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1031
+ return SequenceClassifierOutput(
1032
+ loss=masked_lm_loss,
1033
+ logits=logits,
1034
+ hidden_states=outputs.hidden_states,
1035
+ attentions=attentions
1036
+ )
1037
+
1038
+
1039
+ class RobertaPrefixForMaskedLM(RobertaPreTrainedModel):
1040
+ def __init__(self, config):
1041
+ super().__init__(config)
1042
+ self.num_labels = config.num_labels
1043
+ self.config = config
1044
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1045
+ self.lm_head = RobertaLMHead(config)
1046
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
1047
+
1048
+ for param in self.roberta.parameters():
1049
+ param.requires_grad = False
1050
+
1051
+ self.pre_seq_len = config.pre_seq_len
1052
+ self.n_layer = config.num_hidden_layers
1053
+ self.n_head = config.num_attention_heads
1054
+ self.n_embd = config.hidden_size // config.num_attention_heads
1055
+
1056
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
1057
+ self.prefix_encoder = PrefixEncoder(config)
1058
+
1059
+ bert_param = 0
1060
+ for name, param in self.roberta.named_parameters():
1061
+ bert_param += param.numel()
1062
+ all_param = 0
1063
+ for name, param in self.named_parameters():
1064
+ all_param += param.numel()
1065
+ total_param = all_param - bert_param
1066
+ print('-> total param is {}'.format(total_param)) # 9860105
1067
+
1068
+ self.embedding = utils.get_embeddings(self, config)
1069
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
1070
+ self.clean_labels = torch.tensor(config.clean_labels).long()
1071
+
1072
+ def get_prompt(self, batch_size):
1073
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
1074
+ past_key_values = self.prefix_encoder(prefix_tokens)
1075
+ past_key_values = past_key_values.view(
1076
+ batch_size,
1077
+ self.pre_seq_len,
1078
+ self.n_layer * 2,
1079
+ self.n_head,
1080
+ self.n_embd
1081
+ )
1082
+ past_key_values = self.dropout(past_key_values)
1083
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
1084
+ return past_key_values
1085
+
1086
+ def forward(
1087
+ self,
1088
+ input_ids=None,
1089
+ attention_mask=None,
1090
+ token_type_ids=None,
1091
+ position_ids=None,
1092
+ head_mask=None,
1093
+ inputs_embeds=None,
1094
+ labels=None,
1095
+ token_labels=None,
1096
+ output_attentions=None,
1097
+ output_hidden_states=None,
1098
+ return_dict=None,
1099
+ use_base_grad=False,
1100
+ ):
1101
+ utils.use_grad(self.roberta, use_base_grad)
1102
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1103
+
1104
+ batch_size = input_ids.shape[0]
1105
+ past_key_values = self.get_prompt(batch_size=batch_size)
1106
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
1107
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
1108
+
1109
+ outputs = self.roberta(
1110
+ input_ids,
1111
+ attention_mask=attention_mask,
1112
+ token_type_ids=token_type_ids,
1113
+ position_ids=position_ids,
1114
+ head_mask=head_mask,
1115
+ inputs_embeds=inputs_embeds,
1116
+ output_attentions=output_attentions,
1117
+ output_hidden_states=output_hidden_states,
1118
+ return_dict=return_dict,
1119
+ past_key_values=past_key_values,
1120
+ )
1121
+ sequence_output = outputs[0]
1122
+ cls_token = sequence_output[:, 0]
1123
+ cls_token = self.dropout(cls_token)
1124
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size)
1125
+
1126
+ # compute loss
1127
+ masked_lm_loss = None
1128
+ if token_labels is not None:
1129
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
1130
+ else:
1131
+ if labels is not None:
1132
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
1133
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
1134
+
1135
+ # convert to binary classifier
1136
+ probs = []
1137
+ for y in self.clean_labels:
1138
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
1139
+ logits = torch.stack(probs).T
1140
+
1141
+ if not return_dict:
1142
+ output = (logits,) + outputs[2:]
1143
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1144
+ return SequenceClassifierOutput(
1145
+ loss=masked_lm_loss,
1146
+ logits=logits,
1147
+ hidden_states=outputs.hidden_states,
1148
+ attentions=attentions
1149
+ )
1150
+
1151
+
1152
+ class RobertaPromptForMaskedLM(RobertaPreTrainedModel):
1153
+ def __init__(self, config):
1154
+ super().__init__(config)
1155
+ self.num_labels = config.num_labels
1156
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
1157
+ self.lm_head = RobertaLMHead(config)
1158
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
1159
+ for param in self.roberta.parameters():
1160
+ param.requires_grad = False
1161
+
1162
+ self.pre_seq_len = config.pre_seq_len
1163
+ self.n_layer = config.num_hidden_layers
1164
+ self.n_head = config.num_attention_heads
1165
+ self.n_embd = config.hidden_size // config.num_attention_heads
1166
+
1167
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
1168
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
1169
+
1170
+ self.embeddings = self.roberta.embeddings
1171
+ self.embedding = utils.get_embeddings(self, config)
1172
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
1173
+ self.clean_labels = torch.tensor(config.clean_labels).long()
1174
+
1175
+ def get_prompt(self, batch_size):
1176
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
1177
+ prompts = self.prefix_encoder(prefix_tokens)
1178
+ return prompts
1179
+
1180
+ def forward(
1181
+ self,
1182
+ input_ids=None,
1183
+ attention_mask=None,
1184
+ token_type_ids=None,
1185
+ position_ids=None,
1186
+ head_mask=None,
1187
+ inputs_embeds=None,
1188
+ labels=None,
1189
+ token_labels=None,
1190
+ output_attentions=None,
1191
+ output_hidden_states=None,
1192
+ return_dict=None,
1193
+ use_base_grad=False
1194
+ ):
1195
+ utils.use_grad(self.roberta, use_base_grad)
1196
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1197
+
1198
+ batch_size = input_ids.shape[0]
1199
+ raw_embedding = self.roberta.embeddings(
1200
+ input_ids=input_ids,
1201
+ position_ids=position_ids,
1202
+ token_type_ids=token_type_ids,
1203
+ )
1204
+ prompts = self.get_prompt(batch_size=batch_size)
1205
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
1206
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
1207
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
1208
+
1209
+ outputs = self.roberta(
1210
+ # input_ids,
1211
+ attention_mask=attention_mask,
1212
+ # token_type_ids=token_type_ids,
1213
+ # position_ids=position_ids,
1214
+ head_mask=head_mask,
1215
+ inputs_embeds=inputs_embeds,
1216
+ output_attentions=output_attentions,
1217
+ output_hidden_states=output_hidden_states,
1218
+ return_dict=return_dict,
1219
+ # past_key_values=past_key_values,
1220
+ )
1221
+ sequence_output = outputs[0]
1222
+ sequence_output = self.dropout(sequence_output)
1223
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
1224
+ cls_token = sequence_output[:, 0]
1225
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size)
1226
+
1227
+ masked_lm_loss = None
1228
+ if token_labels is not None:
1229
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
1230
+ else:
1231
+ if labels is not None:
1232
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
1233
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
1234
+
1235
+ # convert to binary classifier
1236
+ probs = []
1237
+ for y in self.clean_labels:
1238
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
1239
+ logits = torch.stack(probs).T
1240
+
1241
+ if not return_dict:
1242
+ output = (logits,) + outputs[2:]
1243
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1244
+ return SequenceClassifierOutput(
1245
+ loss=masked_lm_loss,
1246
+ logits=logits,
1247
+ hidden_states=outputs.hidden_states,
1248
+ attentions=attentions
1249
+ )
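Every head in this file maps integer class labels to `clean_labels` verbalizer token ids before the loss is computed; a small sketch with placeholder token ids:

```python
import torch

clean_labels = torch.tensor([[1037, 2210],    # verbalizer token ids for class 0
                             [2307, 6429]])   # verbalizer token ids for class 1
labels = torch.tensor([1, 0, 1])

token_labels = torch.stack([clean_labels[labels[i]] for i in range(len(labels))])
# tensor([[2307, 6429], [1037, 2210], [2307, 6429]])
assert token_labels.shape == (3, 2)
```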
soft_prompt/model/sequence_classification.py ADDED
@@ -0,0 +1,997 @@
1
+ import torch
2
+ from torch._C import NoopLogger
3
+ import torch.nn
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
7
+
8
+ from transformers import BertModel, BertPreTrainedModel
9
+ from transformers import RobertaModel, RobertaPreTrainedModel
10
+ from transformers.modeling_outputs import SequenceClassifierOutput, SequenceClassifierOutputWithPast, BaseModelOutput, Seq2SeqLMOutput
11
+ from transformers import GPT2Model, GPT2PreTrainedModel, GPTNeoModel
12
+
13
+ from model.prefix_encoder import PrefixEncoder
14
+ from model.deberta import DebertaModel, DebertaPreTrainedModel, ContextPooler, StableDropout
15
+ from model import utils
16
+ import copy
17
+
18
+ class BertForSequenceClassification(BertPreTrainedModel):
19
+ def __init__(self, config):
20
+ super().__init__(config)
21
+ self.num_labels = config.num_labels
22
+ self.config = config
23
+
24
+ self.bert = BertModel(config)
25
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
26
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
27
+
28
+ self.init_weights()
29
+
30
+ self.embedding = utils.get_embeddings(self, config)
31
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
32
+
33
+ def forward(
34
+ self,
35
+ input_ids=None,
36
+ attention_mask=None,
37
+ token_type_ids=None,
38
+ position_ids=None,
39
+ head_mask=None,
40
+ inputs_embeds=None,
41
+ labels=None,
42
+ output_attentions=None,
43
+ output_hidden_states=None,
44
+ return_dict=None,
45
+ use_base_grad=False,
46
+ ):
47
+ r"""
48
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
49
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
50
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
51
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
52
+ """
53
+ utils.use_grad(self.bert, use_base_grad)
54
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
55
+
56
+ outputs = self.bert(
57
+ input_ids,
58
+ attention_mask=attention_mask,
59
+ token_type_ids=token_type_ids,
60
+ position_ids=position_ids,
61
+ head_mask=head_mask,
62
+ inputs_embeds=inputs_embeds,
63
+ output_attentions=output_attentions,
64
+ output_hidden_states=output_hidden_states,
65
+ return_dict=return_dict,
66
+ )
67
+
68
+ pooled_output = outputs[1]
69
+
70
+ pooled_output = self.dropout(pooled_output)
71
+ logits = self.classifier(pooled_output)
72
+
73
+ loss = None
74
+ if labels is not None:
75
+ if self.config.problem_type is None:
76
+ if self.num_labels == 1:
77
+ self.config.problem_type = "regression"
78
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
79
+ self.config.problem_type = "single_label_classification"
80
+ else:
81
+ self.config.problem_type = "multi_label_classification"
82
+
83
+ if self.config.problem_type == "regression":
84
+ loss_fct = MSELoss()
85
+ if self.num_labels == 1:
86
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
87
+ else:
88
+ loss = loss_fct(logits, labels)
89
+ elif self.config.problem_type == "single_label_classification":
90
+ loss_fct = CrossEntropyLoss()
91
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
92
+ elif self.config.problem_type == "multi_label_classification":
93
+ loss_fct = BCEWithLogitsLoss()
94
+ loss = loss_fct(logits, labels)
95
+ elif self.config.problem_type == "em":
96
+ predict_logp = F.log_softmax(pooled_output, dim=-1)
97
+ target_logp = predict_logp.gather(-1, labels)
98
+ target_logp = target_logp - 1e32 * labels.eq(0) # Apply mask
99
+ loss = -torch.logsumexp(target_logp, dim=-1)
100
+
101
+ if not return_dict:
102
+ output = (logits,) + outputs[2:]
103
+ return ((loss,) + output) if loss is not None else output
104
+
105
+ if loss is not None:  # guard: loss is undefined when no labels are given
+     loss.backward()
106
+
107
+ return SequenceClassifierOutput(
108
+ loss=loss,
109
+ logits=pooled_output,
110
+ hidden_states=outputs.hidden_states,
111
+ attentions=outputs.attentions,
112
+ )
113
+
114
+
115
+ class BertPrefixForSequenceClassification(BertPreTrainedModel):
116
+ def __init__(self, config):
117
+ super().__init__(config)
118
+ self.num_labels = config.num_labels
119
+ self.config = config
120
+ self.bert = BertModel(config)
121
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
122
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
123
+
124
+ for param in self.bert.parameters():
125
+ param.requires_grad = False
126
+
127
+ self.pre_seq_len = config.pre_seq_len
128
+ self.n_layer = config.num_hidden_layers
129
+ self.n_head = config.num_attention_heads
130
+ self.n_embd = config.hidden_size // config.num_attention_heads
131
+
132
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
133
+ self.prefix_encoder = PrefixEncoder(config)
134
+
135
+ bert_param = 0
136
+ for name, param in self.bert.named_parameters():
137
+ bert_param += param.numel()
138
+ all_param = 0
139
+ for name, param in self.named_parameters():
140
+ all_param += param.numel()
141
+ total_param = all_param - bert_param
142
+ print('-> bert_param: {:0.2f}M, P-tuning-V2 param: {}'.format(bert_param / 1000000, total_param))
143
+
144
+ self.embedding = utils.get_embeddings(self, config)
145
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
146
+
147
+ def get_prompt(self, batch_size):
148
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
149
+ past_key_values = self.prefix_encoder(prefix_tokens)
150
+ # bsz, seqlen, _ = past_key_values.shape
151
+ past_key_values = past_key_values.view(
152
+ batch_size,
153
+ self.pre_seq_len,
154
+ self.n_layer * 2,
155
+ self.n_head,
156
+ self.n_embd
157
+ )
158
+ past_key_values = self.dropout(past_key_values)
159
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
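+ # Result: a tuple of num_hidden_layers tensors, each shaped (2, batch, n_head, pre_seq_len, head_dim), passed as past_key_values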
160
+ return past_key_values
161
+
162
+ def forward(
163
+ self,
164
+ input_ids=None,
165
+ attention_mask=None,
166
+ token_type_ids=None,
167
+ position_ids=None,
168
+ head_mask=None,
169
+ inputs_embeds=None,
170
+ labels=None,
171
+ output_attentions=None,
172
+ output_hidden_states=None,
173
+ return_dict=None,
174
+ use_base_grad=False,
175
+ ):
176
+ utils.use_grad(self.bert, use_base_grad)
177
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
178
+ batch_size = input_ids.shape[0]
179
+ past_key_values = self.get_prompt(batch_size=batch_size)
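+ # Prepend an all-ones mask so attention also covers the virtual prefix tokens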
180
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
181
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
182
+
183
+ outputs = self.bert(
184
+ input_ids,
185
+ attention_mask=attention_mask,
186
+ token_type_ids=token_type_ids,
187
+ position_ids=position_ids,
188
+ head_mask=head_mask,
189
+ inputs_embeds=inputs_embeds,
190
+ output_attentions=output_attentions,
191
+ output_hidden_states=output_hidden_states,
192
+ return_dict=return_dict,
193
+ past_key_values=past_key_values,
194
+ )
195
+ pooled_output = outputs[1]
196
+ pooled_output = self.dropout(pooled_output)
197
+ logits = self.classifier(pooled_output)
198
+
199
+ loss = None
200
+ if labels is not None:
201
+ if self.config.problem_type is None:
202
+ if self.num_labels == 1:
203
+ self.config.problem_type = "regression"
204
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
205
+ self.config.problem_type = "single_label_classification"
206
+ else:
207
+ self.config.problem_type = "multi_label_classification"
208
+
209
+ if self.config.problem_type == "regression":
210
+ loss_fct = MSELoss()
211
+ if self.num_labels == 1:
212
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
213
+ else:
214
+ loss = loss_fct(logits, labels)
215
+ elif self.config.problem_type == "single_label_classification":
216
+ loss_fct = CrossEntropyLoss()
217
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
218
+ elif self.config.problem_type == "multi_label_classification":
219
+ loss_fct = BCEWithLogitsLoss()
220
+ loss = loss_fct(logits, labels)
221
+ if not return_dict:
222
+ output = (logits,) + outputs[2:]
223
+ return ((loss,) + output) if loss is not None else output
224
+
225
+ return SequenceClassifierOutput(
226
+ loss=loss,
227
+ logits=logits,
228
+ hidden_states=outputs.hidden_states,
229
+ attentions=outputs.attentions,
230
+ )
231
+
232
+
233
+ class BertPromptForSequenceClassification(BertPreTrainedModel):
234
+ def __init__(self, config):
235
+ super().__init__(config)
236
+ self.num_labels = config.num_labels
237
+ self.bert = BertModel(config)
238
+ self.embeddings = self.bert.embeddings
239
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
240
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
241
+
242
+ for param in self.bert.parameters():
243
+ param.requires_grad = False
244
+
245
+ self.pre_seq_len = config.pre_seq_len
246
+ self.n_layer = config.num_hidden_layers
247
+ self.n_head = config.num_attention_heads
248
+ self.n_embd = config.hidden_size // config.num_attention_heads
249
+
250
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
251
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
252
+
253
+ self.embedding = utils.get_embeddings(self, config)
254
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
255
+
256
+ def get_prompt(self, batch_size):
257
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
258
+ prompts = self.prefix_encoder(prefix_tokens)
259
+ return prompts
260
+
261
+ def forward(
262
+ self,
263
+ input_ids=None,
264
+ attention_mask=None,
265
+ token_type_ids=None,
266
+ position_ids=None,
267
+ head_mask=None,
268
+ inputs_embeds=None,
269
+ labels=None,
270
+ output_attentions=None,
271
+ output_hidden_states=None,
272
+ return_dict=None,
273
+ use_base_grad=False,
274
+ ):
275
+ utils.use_grad(self.bert, use_base_grad)
276
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
277
+
278
+ batch_size = input_ids.shape[0]
279
+ raw_embedding = self.embeddings(
280
+ input_ids=input_ids,
281
+ position_ids=position_ids,
282
+ token_type_ids=token_type_ids,
283
+ )
284
+ prompts = self.get_prompt(batch_size=batch_size)
285
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
286
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
287
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
288
+
289
+ outputs = self.bert(
290
+ # input_ids,
291
+ attention_mask=attention_mask,
292
+ # token_type_ids=token_type_ids,
293
+ # position_ids=position_ids,
294
+ head_mask=head_mask,
295
+ inputs_embeds=inputs_embeds,
296
+ output_attentions=output_attentions,
297
+ output_hidden_states=output_hidden_states,
298
+ return_dict=return_dict,
299
+ # past_key_values=past_key_values,
300
+ )
301
+
302
+ # pooled_output = outputs[1]
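+ # Drop the prompt positions, then re-apply BERT's pooler to the first real token ([CLS])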
303
+ sequence_output = outputs[0]
304
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
305
+ first_token_tensor = sequence_output[:, 0]
306
+ pooled_output = self.bert.pooler.dense(first_token_tensor)
307
+ pooled_output = self.bert.pooler.activation(pooled_output)
308
+
309
+ pooled_output = self.dropout(pooled_output)
310
+ logits = self.classifier(pooled_output)
311
+
312
+ loss = None
313
+ if labels is not None:
314
+ if self.config.problem_type is None:
315
+ if self.num_labels == 1:
316
+ self.config.problem_type = "regression"
317
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
318
+ self.config.problem_type = "single_label_classification"
319
+ else:
320
+ self.config.problem_type = "multi_label_classification"
321
+
322
+ if self.config.problem_type == "regression":
323
+ loss_fct = MSELoss()
324
+ if self.num_labels == 1:
325
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
326
+ else:
327
+ loss = loss_fct(logits, labels)
328
+ elif self.config.problem_type == "single_label_classification":
329
+ loss_fct = CrossEntropyLoss()
330
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
331
+ elif self.config.problem_type == "multi_label_classification":
332
+ loss_fct = BCEWithLogitsLoss()
333
+ loss = loss_fct(logits, labels)
334
+ if not return_dict:
335
+ output = (logits,) + outputs[2:]
336
+ return ((loss,) + output) if loss is not None else output
337
+
338
+ return SequenceClassifierOutput(
339
+ loss=loss,
340
+ logits=logits,
341
+ hidden_states=outputs.hidden_states,
342
+ attentions=outputs.attentions,
343
+ )
344
+
345
+
346
+ class RobertaPrefixForSequenceClassification(RobertaPreTrainedModel):
347
+ def __init__(self, config):
348
+ super().__init__(config)
349
+ self.num_labels = config.num_labels
350
+ self.config = config
351
+ self.roberta = RobertaModel(config)
352
+
353
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
354
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
355
+ self.init_weights()
356
+
357
+ for param in self.roberta.parameters():
358
+ param.requires_grad = False
359
+
360
+ self.pre_seq_len = config.pre_seq_len
361
+ self.n_layer = config.num_hidden_layers
362
+ self.n_head = config.num_attention_heads
363
+ self.n_embd = config.hidden_size // config.num_attention_heads
364
+
365
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
366
+ self.prefix_encoder = PrefixEncoder(config)
367
+
368
+ bert_param = 0
369
+ for name, param in self.roberta.named_parameters():
370
+ bert_param += param.numel()
371
+ all_param = 0
372
+ for name, param in self.named_parameters():
373
+ all_param += param.numel()
374
+ total_param = all_param - bert_param
375
+ print('-> total param is {}'.format(total_param)) # 9860105
376
+
377
+ self.embedding = utils.get_embeddings(self, config)
378
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
379
+
380
+ def get_prompt(self, batch_size):
381
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
382
+ past_key_values = self.prefix_encoder(prefix_tokens)
383
+ past_key_values = past_key_values.view(
384
+ batch_size,
385
+ self.pre_seq_len,
386
+ self.n_layer * 2,
387
+ self.n_head,
388
+ self.n_embd
389
+ )
390
+ past_key_values = self.dropout(past_key_values)
391
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
392
+ return past_key_values
393
+
394
+ def forward(
395
+ self,
396
+ input_ids=None,
397
+ attention_mask=None,
398
+ token_type_ids=None,
399
+ position_ids=None,
400
+ head_mask=None,
401
+ inputs_embeds=None,
402
+ labels=None,
403
+ output_attentions=None,
404
+ output_hidden_states=None,
405
+ return_dict=None,
406
+ use_base_grad=False,
407
+ ):
408
+ utils.use_grad(self.roberta, use_base_grad)
409
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
410
+
411
+ batch_size = input_ids.shape[0]
412
+ past_key_values = self.get_prompt(batch_size=batch_size)
413
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
414
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
415
+
416
+ outputs = self.roberta(
417
+ input_ids,
418
+ attention_mask=attention_mask,
419
+ token_type_ids=token_type_ids,
420
+ position_ids=position_ids,
421
+ head_mask=head_mask,
422
+ inputs_embeds=inputs_embeds,
423
+ output_attentions=output_attentions,
424
+ output_hidden_states=output_hidden_states,
425
+ return_dict=return_dict,
426
+ past_key_values=past_key_values,
427
+ )
428
+
429
+ pooled_output = outputs[1]
430
+ pooled_output = self.dropout(pooled_output)
431
+ logits = self.classifier(pooled_output)
432
+
433
+ loss = None
434
+ if labels is not None:
435
+ if self.config.problem_type is None:
436
+ if self.num_labels == 1:
437
+ self.config.problem_type = "regression"
438
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
439
+ self.config.problem_type = "single_label_classification"
440
+ else:
441
+ self.config.problem_type = "multi_label_classification"
442
+
443
+ if self.config.problem_type == "regression":
444
+ loss_fct = MSELoss()
445
+ if self.num_labels == 1:
446
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
447
+ else:
448
+ loss = loss_fct(logits, labels)
449
+ elif self.config.problem_type == "single_label_classification":
450
+ loss_fct = CrossEntropyLoss()
451
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
452
+ elif self.config.problem_type == "multi_label_classification":
453
+ loss_fct = BCEWithLogitsLoss()
454
+ loss = loss_fct(logits, labels)
455
+ if not return_dict:
456
+ output = (logits,) + outputs[2:]
457
+ return ((loss,) + output) if loss is not None else output
458
+
459
+ return SequenceClassifierOutput(
460
+ loss=loss,
461
+ logits=logits,
462
+ hidden_states=outputs.hidden_states,
463
+ attentions=outputs.attentions,
464
+ )
465
+
466
+
467
+ class RobertaPromptForSequenceClassification(RobertaPreTrainedModel):
468
+ def __init__(self, config):
469
+ super().__init__(config)
470
+ self.num_labels = config.num_labels
471
+ self.roberta = RobertaModel(config)
472
+ self.embeddings = self.roberta.embeddings
473
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
474
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
475
+
476
+ for param in self.roberta.parameters():
477
+ param.requires_grad = False
478
+
479
+ self.pre_seq_len = config.pre_seq_len
480
+ self.n_layer = config.num_hidden_layers
481
+ self.n_head = config.num_attention_heads
482
+ self.n_embd = config.hidden_size // config.num_attention_heads
483
+
484
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
485
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
486
+
487
+ self.embedding = utils.get_embeddings(self, config)
488
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
489
+
490
+ def get_prompt(self, batch_size):
491
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
492
+ prompts = self.prefix_encoder(prefix_tokens)
493
+ return prompts
494
+
495
+ def forward(
496
+ self,
497
+ input_ids=None,
498
+ attention_mask=None,
499
+ token_type_ids=None,
500
+ position_ids=None,
501
+ head_mask=None,
502
+ inputs_embeds=None,
503
+ labels=None,
504
+ output_attentions=None,
505
+ output_hidden_states=None,
506
+ return_dict=None,
507
+ use_base_grad=False
508
+ ):
509
+ utils.use_grad(self.roberta, use_base_grad)
510
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
511
+
512
+ batch_size = input_ids.shape[0]
513
+ raw_embedding = self.embeddings(
514
+ input_ids=input_ids,
515
+ position_ids=position_ids,
516
+ token_type_ids=token_type_ids,
517
+ )
518
+ prompts = self.get_prompt(batch_size=batch_size)
519
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
520
+
521
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
522
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
523
+
524
+ outputs = self.roberta(
525
+ # input_ids,
526
+ attention_mask=attention_mask,
527
+ # token_type_ids=token_type_ids,
528
+ # position_ids=position_ids,
529
+ head_mask=head_mask,
530
+ inputs_embeds=inputs_embeds,
531
+ output_attentions=output_attentions,
532
+ output_hidden_states=output_hidden_states,
533
+ return_dict=return_dict,
534
+ # past_key_values=past_key_values,
535
+ )
536
+
537
+ # pooled_output = outputs[1]
538
+ sequence_output = outputs[0]
539
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
540
+ first_token_tensor = sequence_output[:, 0]
541
+ pooled_output = self.roberta.pooler.dense(first_token_tensor)
542
+ pooled_output = self.roberta.pooler.activation(pooled_output)
543
+
544
+ pooled_output = self.dropout(pooled_output)
545
+ logits = self.classifier(pooled_output)
546
+
547
+ loss = None
548
+ if labels is not None:
549
+ if self.config.problem_type is None:
550
+ if self.num_labels == 1:
551
+ self.config.problem_type = "regression"
552
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
553
+ self.config.problem_type = "single_label_classification"
554
+ else:
555
+ self.config.problem_type = "multi_label_classification"
556
+
557
+ if self.config.problem_type == "regression":
558
+ loss_fct = MSELoss()
559
+ if self.num_labels == 1:
560
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
561
+ else:
562
+ loss = loss_fct(logits, labels)
563
+ elif self.config.problem_type == "single_label_classification":
564
+ loss_fct = CrossEntropyLoss()
565
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
566
+ elif self.config.problem_type == "multi_label_classification":
567
+ loss_fct = BCEWithLogitsLoss()
568
+ loss = loss_fct(logits, labels)
569
+ if not return_dict:
570
+ output = (logits,) + outputs[2:]
571
+ return ((loss,) + output) if loss is not None else output
572
+
573
+ return SequenceClassifierOutput(
574
+ loss=loss,
575
+ logits=logits,
576
+ hidden_states=outputs.hidden_states,
577
+ attentions=outputs.attentions,
578
+ )
579
+
580
+
581
+ class DebertaPrefixForSequenceClassification(DebertaPreTrainedModel):
582
+ def __init__(self, config):
583
+ super().__init__(config)
584
+ self.num_labels = config.num_labels
585
+ self.config = config
586
+ self.deberta = DebertaModel(config)
587
+ self.pooler = ContextPooler(config)
588
+ output_dim = self.pooler.output_dim
589
+ self.classifier = torch.nn.Linear(output_dim, self.num_labels)
590
+ self.dropout = StableDropout(config.hidden_dropout_prob)
591
+ self.init_weights()
592
+
593
+ for param in self.deberta.parameters():
594
+ param.requires_grad = False
595
+
596
+ self.pre_seq_len = config.pre_seq_len
597
+ self.n_layer = config.num_hidden_layers
598
+ self.n_head = config.num_attention_heads
599
+ self.n_embd = config.hidden_size // config.num_attention_heads
600
+
601
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
602
+ self.prefix_encoder = PrefixEncoder(config)
603
+
604
+ deberta_param = 0
605
+ for name, param in self.deberta.named_parameters():
606
+ deberta_param += param.numel()
607
+ all_param = 0
608
+ for name, param in self.named_parameters():
609
+ all_param += param.numel()
610
+ total_param = all_param - deberta_param
611
+ print('total param is {}'.format(total_param)) # 9860105
612
+
613
+ self.embedding = utils.get_embeddings(self, config)
614
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
615
+
616
+ def get_prompt(self, batch_size):
617
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
618
+ past_key_values = self.prefix_encoder(prefix_tokens)
619
+ # bsz, seqlen, _ = past_key_values.shape
620
+ past_key_values = past_key_values.view(
621
+ batch_size,
622
+ self.pre_seq_len,
623
+ self.n_layer * 2,
624
+ self.n_head,
625
+ self.n_embd
626
+ )
627
+ past_key_values = self.dropout(past_key_values)
628
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
629
+ return past_key_values
630
+
631
+ def forward(
632
+ self,
633
+ input_ids=None,
634
+ attention_mask=None,
635
+ token_type_ids=None,
636
+ position_ids=None,
637
+ head_mask=None,
638
+ inputs_embeds=None,
639
+ labels=None,
640
+ output_attentions=None,
641
+ output_hidden_states=None,
642
+ return_dict=None,
643
+ use_base_grad=False
644
+ ):
645
+ utils.use_grad(self.deberta, use_base_grad)
646
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
647
+ batch_size = input_ids.shape[0]
648
+ past_key_values = self.get_prompt(batch_size=batch_size)
649
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
650
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
651
+
652
+ outputs = self.deberta(
653
+ input_ids,
654
+ attention_mask=attention_mask,
655
+ token_type_ids=token_type_ids,
656
+ position_ids=position_ids,
657
+ inputs_embeds=inputs_embeds,
658
+ output_attentions=output_attentions,
659
+ output_hidden_states=output_hidden_states,
660
+ return_dict=return_dict,
661
+ past_key_values=past_key_values,
662
+ )
663
+
664
+ encoder_layer = outputs[0]
665
+ pooled_output = self.pooler(encoder_layer)
666
+ pooled_output = self.dropout(pooled_output)
667
+ logits = self.classifier(pooled_output)
668
+
669
+ loss = None
670
+ if labels is not None:
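+ # DeBERTa-style loss: MSE for a single regression label, cross-entropy over rows with a valid (>= 0) label, otherwise soft-label NLL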
671
+ if self.num_labels == 1:
672
+ # regression task
673
+ loss_fn = torch.nn.MSELoss()
674
+ logits = logits.view(-1).to(labels.dtype)
675
+ loss = loss_fn(logits, labels.view(-1))
676
+ elif labels.dim() == 1 or labels.size(-1) == 1:
677
+ label_index = (labels >= 0).nonzero()
678
+ labels = labels.long()
679
+ if label_index.size(0) > 0:
680
+ labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
681
+ labels = torch.gather(labels, 0, label_index.view(-1))
682
+ loss_fct = CrossEntropyLoss()
683
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
684
+ else:
685
+ loss = torch.tensor(0).to(logits)
686
+ else:
687
+ log_softmax = torch.nn.LogSoftmax(-1)
688
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
689
+ if not return_dict:
690
+ output = (logits,) + outputs[1:]
691
+ return ((loss,) + output) if loss is not None else output
692
+ else:
693
+ return SequenceClassifierOutput(
694
+ loss=loss,
695
+ logits=logits,
696
+ hidden_states=outputs.hidden_states,
697
+ attentions=outputs.attentions,
698
+ )
699
+
700
+
701
+ class GPT2PromptForSequenceClassification(GPT2PreTrainedModel):
702
+ def __init__(self, config):
703
+ super().__init__(config)
704
+ self.num_labels = config.num_labels
705
+ self.config = config
706
+ self.gpt2 = GPT2Model(config)
707
+ self.dropout = StableDropout(config.embd_pdrop)
708
+ self.classifier = torch.nn.Linear(config.n_embd, self.num_labels)
709
+
710
+ for param in self.gpt2.parameters():
711
+ param.requires_grad = False
712
+
713
+ self.pre_seq_len = config.pre_seq_len
714
+ self.n_layer = config.num_hidden_layers
715
+ self.n_head = config.num_attention_heads
716
+ self.n_embd = config.hidden_size // config.num_attention_heads
717
+
718
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
719
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
720
+
721
+ # Model parallel
722
+ self.model_parallel = False
723
+ self.device_map = None
724
+
725
+ gpt2_param = 0
726
+ for name, param in self.gpt2.named_parameters():
727
+ gpt2_param += param.numel()
728
+ all_param = 0
729
+ for name, param in self.named_parameters():
730
+ all_param += param.numel()
731
+ total_param = all_param - gpt2_param
732
+ print('-> total param is {}'.format(total_param)) # 9860105
733
+
734
+ self.embedding = self.gpt2.wte
735
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
736
+
737
+ def get_prompt(self, batch_size):
738
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.gpt2.device)
739
+ prompts = self.prefix_encoder(prefix_tokens)
740
+ return prompts
741
+
742
+ def forward(
743
+ self,
744
+ input_ids=None,
745
+ attention_mask=None,
746
+ token_type_ids=None,
747
+ position_ids=None,
748
+ head_mask=None,
749
+ inputs_embeds=None,
750
+ labels=None,
751
+ output_attentions=None,
752
+ output_hidden_states=None,
753
+ return_dict=None,
754
+ use_base_grad=False
755
+ ):
756
+ utils.use_grad(self.gpt2, use_base_grad)
757
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
758
+
759
+ batch_size = input_ids.shape[0]
760
+ raw_embedding = self.embedding(input_ids)
761
+ prompts = self.get_prompt(batch_size=batch_size)
762
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
763
+
764
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.gpt2.device)
765
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
766
+
767
+ transformer_outputs = self.gpt2(
768
+ # input_ids,
769
+ attention_mask=attention_mask,
770
+ # token_type_ids=token_type_ids,
771
+ # position_ids=position_ids,
772
+ head_mask=head_mask,
773
+ inputs_embeds=inputs_embeds,
774
+ output_attentions=output_attentions,
775
+ output_hidden_states=output_hidden_states,
776
+ return_dict=return_dict,
777
+ # past_key_values=past_key_values,
778
+ )
779
+
780
+ hidden_states = transformer_outputs[0]
781
+ logits = self.classifier(hidden_states)
782
+
783
+ if input_ids is not None:
784
+ batch_size, sequence_length = input_ids.shape[:2]
785
+ else:
786
+ batch_size, sequence_length = inputs_embeds.shape[:2]
787
+
788
+ assert (
789
+ self.config.pad_token_id is not None or batch_size == 1
790
+ ), "Cannot handle batch sizes > 1 if no " \
791
+ "padding token is defined."
792
+ if self.config.pad_token_id is None:
793
+ sequence_lengths = -1
794
+ else:
795
+ if input_ids is not None:
796
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
797
+ else:
798
+ sequence_lengths = -1
799
+
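+ # GPT-2 has no [CLS] token: read the classification logit from the last non-padding position of each sequence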
800
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
801
+
802
+ loss = None
803
+ if labels is not None:
804
+ if self.config.problem_type is None:
805
+ if self.num_labels == 1:
806
+ self.config.problem_type = "regression"
807
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
808
+ self.config.problem_type = "single_label_classification"
809
+ else:
810
+ self.config.problem_type = "multi_label_classification"
811
+
812
+ if self.config.problem_type == "regression":
813
+ loss_fct = MSELoss()
814
+ if self.num_labels == 1:
815
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
816
+ else:
817
+ loss = loss_fct(pooled_logits, labels)
818
+ elif self.config.problem_type == "single_label_classification":
819
+ loss_fct = CrossEntropyLoss()
820
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
821
+ elif self.config.problem_type == "multi_label_classification":
822
+ loss_fct = BCEWithLogitsLoss()
823
+ loss = loss_fct(pooled_logits, labels)
824
+ if not return_dict:
825
+ output = (pooled_logits,) + transformer_outputs[1:]
826
+ return ((loss,) + output) if loss is not None else output
827
+
828
+ return SequenceClassifierOutputWithPast(
829
+ loss=loss,
830
+ logits=pooled_logits,
831
+ past_key_values=transformer_outputs.past_key_values,
832
+ hidden_states=transformer_outputs.hidden_states,
833
+ attentions=transformer_outputs.attentions,
834
+ )
835
+
836
+ class GPT2PrefixForSequenceClassification(GPT2PreTrainedModel):
837
+ def __init__(self, config):
838
+ super().__init__(config)
839
+ self.num_labels = config.num_labels
840
+ self.config = config
841
+ self.gpt2 = GPT2Model(config)
842
+ self.dropout = StableDropout(config.hidden_dropout_prob)
843
+ self.classifier = torch.nn.Linear(config.n_embd, self.num_labels)
844
+
845
+ for param in self.gpt2.parameters():
846
+ param.requires_grad = False
847
+
848
+ self.pre_seq_len = config.pre_seq_len
849
+ self.n_layer = config.num_hidden_layers
850
+ self.n_head = config.num_attention_heads
851
+ self.n_embd = config.hidden_size // config.num_attention_heads
852
+
853
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
854
+ self.prefix_encoder = PrefixEncoder(config)
855
+
856
+ # Model parallel
857
+ self.model_parallel = False
858
+ self.device_map = None
859
+
860
+ gpt2_param = 0
861
+ for name, param in self.gpt2.named_parameters():
862
+ gpt2_param += param.numel()
863
+ all_param = 0
864
+ for name, param in self.named_parameters():
865
+ all_param += param.numel()
866
+ total_param = all_param - gpt2_param
867
+ print('-> gpt2_param:{:0.2f}M P-tuning-V2 param is {}'.format(gpt2_param/1000000, total_param))
868
+
869
+ self.embedding = self.gpt2.wte
870
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
871
+
872
+ def get_prompt(self, batch_size):
873
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.gpt2.device)
874
+ past_key_values = self.prefix_encoder(prefix_tokens)
875
+ past_key_values = past_key_values.view(
876
+ batch_size,
877
+ self.pre_seq_len,
878
+ self.n_layer * 2,
879
+ self.n_head,
880
+ self.n_embd
881
+ )
882
+ past_key_values = self.dropout(past_key_values)
883
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
884
+ return past_key_values
885
+
886
+ def forward(
887
+ self,
888
+ input_ids=None,
889
+ attention_mask=None,
890
+ token_type_ids=None,
891
+ position_ids=None,
892
+ head_mask=None,
893
+ inputs_embeds=None,
894
+ labels=None,
895
+ output_attentions=None,
896
+ output_hidden_states=None,
897
+ return_dict=None,
898
+ use_base_grad=False,
899
+ use_cache=None
900
+ ):
901
+ r"""
902
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
903
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
904
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
905
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
906
+ """
907
+ utils.use_grad(self.gpt2, use_base_grad)
908
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
909
+
910
+ batch_size = input_ids.shape[0]
911
+ past_key_values = self.get_prompt(batch_size=batch_size)
912
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.gpt2.device)
913
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
914
+
915
+ transformer_outputs = self.gpt2(
916
+ input_ids,
917
+ past_key_values=past_key_values,
918
+ attention_mask=attention_mask,
919
+ token_type_ids=token_type_ids,
920
+ position_ids=position_ids,
921
+ head_mask=head_mask,
922
+ inputs_embeds=inputs_embeds,
923
+ use_cache=use_cache,
924
+ output_attentions=output_attentions,
925
+ output_hidden_states=output_hidden_states,
926
+ return_dict=return_dict,
927
+ )
928
+ hidden_states = transformer_outputs[0]
929
+ logits = self.classifier(hidden_states)
930
+
931
+ if input_ids is not None:
932
+ batch_size, sequence_length = input_ids.shape[:2]
933
+ else:
934
+ batch_size, sequence_length = inputs_embeds.shape[:2]
935
+
936
+ assert (
937
+ self.config.pad_token_id is not None or batch_size == 1
938
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
939
+ if self.config.pad_token_id is None:
940
+ sequence_lengths = -1
941
+ else:
942
+ if input_ids is not None:
943
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
944
+ else:
945
+ sequence_lengths = -1
946
+
947
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
948
+
949
+ loss = None
950
+ if labels is not None:
951
+ if self.config.problem_type is None:
952
+ if self.num_labels == 1:
953
+ self.config.problem_type = "regression"
954
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
955
+ self.config.problem_type = "single_label_classification"
956
+ else:
957
+ self.config.problem_type = "multi_label_classification"
958
+
959
+ if self.config.problem_type == "regression":
960
+ loss_fct = MSELoss()
961
+ if self.num_labels == 1:
962
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
963
+ else:
964
+ loss = loss_fct(pooled_logits, labels)
965
+ elif self.config.problem_type == "single_label_classification":
966
+ loss_fct = CrossEntropyLoss()
967
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
968
+ elif self.config.problem_type == "multi_label_classification":
969
+ loss_fct = BCEWithLogitsLoss()
970
+ loss = loss_fct(pooled_logits, labels)
971
+ if not return_dict:
972
+ output = (pooled_logits,) + transformer_outputs[1:]
973
+ return ((loss,) + output) if loss is not None else output
974
+
975
+ return SequenceClassifierOutputWithPast(
976
+ loss=loss,
977
+ logits=pooled_logits,
978
+ past_key_values=transformer_outputs.past_key_values,
979
+ hidden_states=transformer_outputs.hidden_states,
980
+ attentions=transformer_outputs.attentions,
981
+ )
982
+
983
+
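+ # Smoke test: build a prefix-tuning GPT-2 classifier and print its parameter shapes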
984
+ if __name__ == "__main__":
985
+ from transformers import AutoConfig
986
+ config = AutoConfig.from_pretrained("gpt2-large")
987
+ config.hidden_dropout_prob = 0.1
988
+ config.pre_seq_len = 128
989
+ config.prefix_projection = True
990
+ config.num_labels = 2
991
+ config.prefix_hidden_size = 1024
992
+ model = GPT2PrefixForSequenceClassification(config)
993
+
994
+ for name, param in model.named_parameters():
995
+ print(name, param.shape)
996
+
997
+
soft_prompt/model/token_classification.py ADDED
@@ -0,0 +1,539 @@
1
+ import torch
2
+ import torch.nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+ from torch.nn import CrossEntropyLoss
6
+
7
+ from transformers import BertModel, BertPreTrainedModel
8
+ from transformers import RobertaModel, RobertaPreTrainedModel
9
+ from transformers.modeling_outputs import TokenClassifierOutput
10
+
11
+ from model.prefix_encoder import PrefixEncoder
12
+ from model.deberta import DebertaModel, DebertaPreTrainedModel
13
+ from model.debertaV2 import DebertaV2Model, DebertaV2PreTrainedModel
14
+
15
+ class BertForTokenClassification(BertPreTrainedModel):
16
+
17
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
18
+
19
+ def __init__(self, config):
20
+ super().__init__(config)
21
+ self.num_labels = config.num_labels
22
+
23
+ self.bert = BertModel(config, add_pooling_layer=False)
24
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
25
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
26
+
27
+ only_cls_head = True  # freeze the BERT encoder and train only the classification head (set to False for SRL)
28
+ if only_cls_head:
29
+ for param in self.bert.parameters():
30
+ param.requires_grad = False
31
+
32
+ self.init_weights()
33
+
34
+ bert_param = 0
35
+ for name, param in self.bert.named_parameters():
36
+ bert_param += param.numel()
37
+ all_param = 0
38
+ for name, param in self.named_parameters():
39
+ all_param += param.numel()
40
+ total_param = all_param - bert_param
41
+ print('total param is {}'.format(total_param))
42
+
43
+
44
+ def forward(
45
+ self,
46
+ input_ids=None,
47
+ attention_mask=None,
48
+ token_type_ids=None,
49
+ position_ids=None,
50
+ head_mask=None,
51
+ inputs_embeds=None,
52
+ labels=None,
53
+ output_attentions=None,
54
+ output_hidden_states=None,
55
+ return_dict=None,
56
+ ):
57
+ r"""
58
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
59
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
60
+ 1]``.
61
+ """
62
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
63
+
64
+ outputs = self.bert(
65
+ input_ids,
66
+ attention_mask=attention_mask,
67
+ token_type_ids=token_type_ids,
68
+ position_ids=position_ids,
69
+ head_mask=head_mask,
70
+ inputs_embeds=inputs_embeds,
71
+ output_attentions=output_attentions,
72
+ output_hidden_states=output_hidden_states,
73
+ return_dict=return_dict,
74
+ )
75
+
76
+ sequence_output = outputs[0]
77
+
78
+ sequence_output = self.dropout(sequence_output)
79
+ logits = self.classifier(sequence_output)
80
+
81
+ loss = None
82
+ if labels is not None:
83
+ loss_fct = CrossEntropyLoss()
84
+ # Only keep active parts of the loss
85
+ if attention_mask is not None:
86
+ active_loss = attention_mask.view(-1) == 1
87
+ active_logits = logits.view(-1, self.num_labels)
88
+ active_labels = torch.where(
89
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
90
+ )
91
+ loss = loss_fct(active_logits, active_labels)
92
+ else:
93
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
94
+
95
+ if not return_dict:
96
+ output = (logits,) + outputs[2:]
97
+ return ((loss,) + output) if loss is not None else output
98
+
99
+ return TokenClassifierOutput(
100
+ loss=loss,
101
+ logits=logits,
102
+ hidden_states=outputs.hidden_states,
103
+ attentions=outputs.attentions,
104
+ )
105
+
106
+
107
+ class BertPrefixForTokenClassification(BertPreTrainedModel):
108
+ def __init__(self, config):
109
+ super().__init__(config)
110
+ self.num_labels = config.num_labels
111
+ self.bert = BertModel(config, add_pooling_layer=False)
112
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
113
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
114
+
115
+ from_pretrained = False
116
+ if from_pretrained:
117
+ self.classifier.load_state_dict(torch.load('model/checkpoint.pkl'))
118
+
119
+ for param in self.bert.parameters():
120
+ param.requires_grad = False
121
+
122
+ self.pre_seq_len = config.pre_seq_len
123
+ self.n_layer = config.num_hidden_layers
124
+ self.n_head = config.num_attention_heads
125
+ self.n_embd = config.hidden_size // config.num_attention_heads
126
+
127
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
128
+ self.prefix_encoder = PrefixEncoder(config)
129
+
130
+
131
+ bert_param = 0
132
+ for name, param in self.bert.named_parameters():
133
+ bert_param += param.numel()
134
+ all_param = 0
135
+ for name, param in self.named_parameters():
136
+ all_param += param.numel()
137
+ total_param = all_param - bert_param
138
+ print('total param is {}'.format(total_param)) # 9860105
139
+
140
+ def get_prompt(self, batch_size):
141
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
142
+ past_key_values = self.prefix_encoder(prefix_tokens)
143
+ # bsz, seqlen, _ = past_key_values.shape
144
+ past_key_values = past_key_values.view(
145
+ batch_size,
146
+ self.pre_seq_len,
147
+ self.n_layer * 2,
148
+ self.n_head,
149
+ self.n_embd
150
+ )
151
+ past_key_values = self.dropout(past_key_values)
152
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
153
+ return past_key_values
154
+
155
+ def forward(
156
+ self,
157
+ input_ids=None,
158
+ attention_mask=None,
159
+ token_type_ids=None,
160
+ position_ids=None,
161
+ head_mask=None,
162
+ inputs_embeds=None,
163
+ labels=None,
164
+ output_attentions=None,
165
+ output_hidden_states=None,
166
+ return_dict=None,
167
+ ):
168
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
169
+
170
+ batch_size = input_ids.shape[0]
171
+ past_key_values = self.get_prompt(batch_size=batch_size)
172
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
173
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
174
+
175
+ outputs = self.bert(
176
+ input_ids,
177
+ attention_mask=attention_mask,
178
+ token_type_ids=token_type_ids,
179
+ position_ids=position_ids,
180
+ head_mask=head_mask,
181
+ inputs_embeds=inputs_embeds,
182
+ output_attentions=output_attentions,
183
+ output_hidden_states=output_hidden_states,
184
+ return_dict=return_dict,
185
+ past_key_values=past_key_values,
186
+ )
187
+
188
+ sequence_output = outputs[0]
189
+ sequence_output = self.dropout(sequence_output)
190
+ logits = self.classifier(sequence_output)
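+ # Drop the virtual prefix positions from the mask so the loss covers only real tokens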
191
+ attention_mask = attention_mask[:,self.pre_seq_len:].contiguous()
192
+
193
+ loss = None
194
+ if labels is not None:
195
+ loss_fct = CrossEntropyLoss()
196
+ # Only keep active parts of the loss
197
+ if attention_mask is not None:
198
+ active_loss = attention_mask.view(-1) == 1
199
+ active_logits = logits.view(-1, self.num_labels)
200
+ active_labels = torch.where(
201
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
202
+ )
203
+ loss = loss_fct(active_logits, active_labels)
204
+ else:
205
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
206
+
207
+ if not return_dict:
208
+ output = (logits,) + outputs[2:]
209
+ return ((loss,) + output) if loss is not None else output
210
+
211
+ return TokenClassifierOutput(
212
+ loss=loss,
213
+ logits=logits,
214
+ hidden_states=outputs.hidden_states,
215
+ attentions=outputs.attentions,
216
+ )
217
+
218
+
219
+
220
+
221
+ class RobertaPrefixForTokenClassification(RobertaPreTrainedModel):
222
+ def __init__(self, config):
223
+ super().__init__(config)
224
+ self.num_labels = config.num_labels
225
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
226
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
227
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
228
+ self.init_weights()
229
+
230
+ for param in self.roberta.parameters():
231
+ param.requires_grad = False
232
+
233
+ self.pre_seq_len = config.pre_seq_len
234
+ self.n_layer = config.num_hidden_layers
235
+ self.n_head = config.num_attention_heads
236
+ self.n_embd = config.hidden_size // config.num_attention_heads
237
+
238
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
239
+ self.prefix_encoder = PrefixEncoder(config)
240
+
241
+ bert_param = 0
242
+ for name, param in self.roberta.named_parameters():
243
+ bert_param += param.numel()
244
+ all_param = 0
245
+ for name, param in self.named_parameters():
246
+ all_param += param.numel()
247
+ total_param = all_param - bert_param
248
+ print('total param is {}'.format(total_param)) # 9860105
249
+
250
+
251
+ def get_prompt(self, batch_size):
252
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
253
+ past_key_values = self.prefix_encoder(prefix_tokens)
254
+ past_key_values = past_key_values.view(
255
+ batch_size,
256
+ self.pre_seq_len,
257
+ self.n_layer * 2,
258
+ self.n_head,
259
+ self.n_embd
260
+ )
261
+ past_key_values = self.dropout(past_key_values)
262
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
263
+ return past_key_values
264
+
265
+ def forward(
266
+ self,
267
+ input_ids=None,
268
+ attention_mask=None,
269
+ token_type_ids=None,
270
+ position_ids=None,
271
+ head_mask=None,
272
+ inputs_embeds=None,
273
+ labels=None,
274
+ output_attentions=None,
275
+ output_hidden_states=None,
276
+ return_dict=None,
277
+ ):
278
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
279
+
280
+ batch_size = input_ids.shape[0]
281
+ past_key_values = self.get_prompt(batch_size=batch_size)
282
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
283
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
284
+
285
+ outputs = self.roberta(
286
+ input_ids,
287
+ attention_mask=attention_mask,
288
+ token_type_ids=token_type_ids,
289
+ position_ids=position_ids,
290
+ head_mask=head_mask,
291
+ inputs_embeds=inputs_embeds,
292
+ output_attentions=output_attentions,
293
+ output_hidden_states=output_hidden_states,
294
+ return_dict=return_dict,
295
+ past_key_values=past_key_values,
296
+ )
297
+
298
+ sequence_output = outputs[0]
299
+ sequence_output = self.dropout(sequence_output)
300
+ logits = self.classifier(sequence_output)
301
+ attention_mask = attention_mask[:,self.pre_seq_len:].contiguous()
302
+
303
+ loss = None
304
+ if labels is not None:
305
+ loss_fct = CrossEntropyLoss()
306
+ # Only keep active parts of the loss
307
+ if attention_mask is not None:
308
+ active_loss = attention_mask.view(-1) == 1
309
+ active_logits = logits.view(-1, self.num_labels)
310
+ active_labels = torch.where(
311
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
312
+ )
313
+ loss = loss_fct(active_logits, active_labels)
314
+ else:
315
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
316
+
317
+ if not return_dict:
318
+ output = (logits,) + outputs[2:]
319
+ return ((loss,) + output) if loss is not None else output
320
+
321
+ return TokenClassifierOutput(
322
+ loss=loss,
323
+ logits=logits,
324
+ hidden_states=outputs.hidden_states,
325
+ attentions=outputs.attentions,
326
+ )
327
+
328
+
329
+ class DebertaPrefixForTokenClassification(DebertaPreTrainedModel):
330
+ def __init__(self, config):
331
+ super().__init__(config)
332
+ self.num_labels = config.num_labels
333
+ self.deberta = DebertaModel(config)
334
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
335
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
336
+ self.init_weights()
337
+
338
+ for param in self.deberta.parameters():
339
+ param.requires_grad = False
340
+
341
+ self.pre_seq_len = config.pre_seq_len
342
+ self.n_layer = config.num_hidden_layers
343
+ self.n_head = config.num_attention_heads
344
+ self.n_embd = config.hidden_size // config.num_attention_heads
345
+
346
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
347
+ self.prefix_encoder = PrefixEncoder(config)
348
+
349
+ deberta_param = 0
350
+ for name, param in self.deberta.named_parameters():
351
+ deberta_param += param.numel()
352
+ all_param = 0
353
+ for name, param in self.named_parameters():
354
+ all_param += param.numel()
355
+ total_param = all_param - deberta_param
356
+ print('total param is {}'.format(total_param)) # 9860105
357
+
358
+ def get_prompt(self, batch_size):
359
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
360
+ past_key_values = self.prefix_encoder(prefix_tokens)
361
+ # bsz, seqlen, _ = past_key_values.shape
362
+ past_key_values = past_key_values.view(
363
+ batch_size,
364
+ self.pre_seq_len,
365
+ self.n_layer * 2,
366
+ self.n_head,
367
+ self.n_embd
368
+ )
369
+ past_key_values = self.dropout(past_key_values)
370
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
371
+ return past_key_values
372
+
373
+ def forward(
374
+ self,
375
+ input_ids=None,
376
+ attention_mask=None,
377
+ token_type_ids=None,
378
+ position_ids=None,
379
+ head_mask=None,
380
+ inputs_embeds=None,
381
+ labels=None,
382
+ output_attentions=None,
383
+ output_hidden_states=None,
384
+ return_dict=None,
385
+ ):
386
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
387
+
388
+ batch_size = input_ids.shape[0]
389
+ past_key_values = self.get_prompt(batch_size=batch_size)
390
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
391
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
392
+
393
+ outputs = self.deberta(
394
+ input_ids,
395
+ attention_mask=attention_mask,
396
+ token_type_ids=token_type_ids,
397
+ position_ids=position_ids,
398
+ inputs_embeds=inputs_embeds,
399
+ output_attentions=output_attentions,
400
+ output_hidden_states=output_hidden_states,
401
+ return_dict=return_dict,
402
+ past_key_values=past_key_values,
403
+ )
404
+
405
+ sequence_output = outputs[0]
406
+ sequence_output = self.dropout(sequence_output)
407
+ logits = self.classifier(sequence_output)
408
+ attention_mask = attention_mask[:,self.pre_seq_len:].contiguous()
409
+
410
+ loss = None
411
+ if labels is not None:
412
+ loss_fct = CrossEntropyLoss()
413
+ # Only keep active parts of the loss
414
+ if attention_mask is not None:
415
+ active_loss = attention_mask.view(-1) == 1
416
+ active_logits = logits.view(-1, self.num_labels)
417
+ active_labels = torch.where(
418
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
419
+ )
420
+ loss = loss_fct(active_logits, active_labels)
421
+ else:
422
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
423
+
424
+ if not return_dict:
425
+ output = (logits,) + outputs[2:]
426
+ return ((loss,) + output) if loss is not None else output
427
+
428
+ return TokenClassifierOutput(
429
+ loss=loss,
430
+ logits=logits,
431
+ hidden_states=outputs.hidden_states,
432
+ attentions=outputs.attentions,
433
+ )
434
+
435
+
436
+ class DebertaV2PrefixForTokenClassification(DebertaV2PreTrainedModel):
437
+ def __init__(self, config):
438
+ super().__init__(config)
439
+ self.num_labels = config.num_labels
440
+ self.deberta = DebertaV2Model(config)
441
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
442
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
443
+ self.init_weights()
444
+
445
+ for param in self.deberta.parameters():
446
+ param.requires_grad = False
447
+
448
+ self.pre_seq_len = config.pre_seq_len
449
+ self.n_layer = config.num_hidden_layers
450
+ self.n_head = config.num_attention_heads
451
+ self.n_embd = config.hidden_size // config.num_attention_heads
452
+
453
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
454
+ self.prefix_encoder = PrefixEncoder(config)
455
+
456
+ deberta_param = 0
457
+ for name, param in self.deberta.named_parameters():
458
+ deberta_param += param.numel()
459
+ all_param = 0
460
+ for name, param in self.named_parameters():
461
+ all_param += param.numel()
462
+ total_param = all_param - deberta_param
463
+ print('total param is {}'.format(total_param)) # 9860105
464
+
465
+ def get_prompt(self, batch_size):
466
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
467
+ past_key_values = self.prefix_encoder(prefix_tokens)
468
+ past_key_values = past_key_values.view(
469
+ batch_size,
470
+ self.pre_seq_len,
471
+ self.n_layer * 2,
472
+ self.n_head,
473
+ self.n_embd
474
+ )
475
+ past_key_values = self.dropout(past_key_values)
476
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
477
+ return past_key_values
478
+
479
+ def forward(
480
+ self,
481
+ input_ids=None,
482
+ attention_mask=None,
483
+ token_type_ids=None,
484
+ position_ids=None,
485
+ head_mask=None,
486
+ inputs_embeds=None,
487
+ labels=None,
488
+ output_attentions=None,
489
+ output_hidden_states=None,
490
+ return_dict=None,
491
+ ):
492
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
493
+
494
+ batch_size = input_ids.shape[0]
495
+ past_key_values = self.get_prompt(batch_size=batch_size)
496
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
497
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
498
+
499
+ outputs = self.deberta(
500
+ input_ids,
501
+ attention_mask=attention_mask,
502
+ token_type_ids=token_type_ids,
503
+ position_ids=position_ids,
504
+ inputs_embeds=inputs_embeds,
505
+ output_attentions=output_attentions,
506
+ output_hidden_states=output_hidden_states,
507
+ return_dict=return_dict,
508
+ past_key_values=past_key_values,
509
+ )
510
+
511
+ sequence_output = outputs[0]
512
+ sequence_output = self.dropout(sequence_output)
513
+ logits = self.classifier(sequence_output)
514
+ attention_mask = attention_mask[:,self.pre_seq_len:].contiguous()
515
+
516
+ loss = None
517
+ if labels is not None:
518
+ loss_fct = CrossEntropyLoss()
519
+ # Only keep active parts of the loss
520
+ if attention_mask is not None:
521
+ active_loss = attention_mask.view(-1) == 1
522
+ active_logits = logits.view(-1, self.num_labels)
523
+ active_labels = torch.where(
524
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
525
+ )
526
+ loss = loss_fct(active_logits, active_labels)
527
+ else:
528
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
529
+
530
+ if not return_dict:
531
+ output = (logits,) + outputs[2:]
532
+ return ((loss,) + output) if loss is not None else output
533
+
534
+ return TokenClassifierOutput(
535
+ loss=loss,
536
+ logits=logits,
537
+ hidden_states=outputs.hidden_states,
538
+ attentions=outputs.attentions,
539
+ )
soft_prompt/model/utils.py ADDED
@@ -0,0 +1,399 @@
1
+ from enum import Enum
2
+ import torch
3
+ from .token_classification import (
4
+ BertPrefixForTokenClassification,
5
+ RobertaPrefixForTokenClassification,
6
+ DebertaPrefixForTokenClassification,
7
+ DebertaV2PrefixForTokenClassification
8
+ )
9
+
10
+ from .sequence_classification import (
11
+ BertPrefixForSequenceClassification,
12
+ BertPromptForSequenceClassification,
13
+ RobertaPrefixForSequenceClassification,
14
+ RobertaPromptForSequenceClassification,
15
+ DebertaPrefixForSequenceClassification,
16
+ GPT2PrefixForSequenceClassification,
17
+ GPT2PromptForSequenceClassification
18
+ )
19
+
20
+ from .question_answering import (
21
+ BertPrefixForQuestionAnswering,
22
+ RobertaPrefixModelForQuestionAnswering,
23
+ DebertaPrefixModelForQuestionAnswering
24
+ )
25
+
26
+ from .multiple_choice import (
27
+ BertPrefixForMultipleChoice,
28
+ RobertaPrefixForMultipleChoice,
29
+ DebertaPrefixForMultipleChoice,
30
+ BertPromptForMultipleChoice,
31
+ RobertaPromptForMultipleChoice
32
+ )
33
+
34
+ from .sequence_causallm import (
35
+ BertPromptForMaskedLM,
36
+ BertPrefixForMaskedLM,
37
+ RobertaPromptForMaskedLM,
38
+ RobertaPrefixForMaskedLM,
39
+ LlamaPromptForMaskedLM,
40
+ LlamaPrefixForMaskedLM,
41
+ OPTPrefixForMaskedLM,
42
+ OPTPromptForMaskedLM
43
+ )
44
+
45
+ from transformers import (
46
+ AutoConfig,
47
+ AutoModelForTokenClassification,
48
+ AutoModelForSequenceClassification,
49
+ AutoModelForQuestionAnswering,
50
+ AutoModelForMultipleChoice
51
+ )
52
+ import torch.nn.functional as F
53
+
54
+
55
+ def get_loss(predict_logits, labels_ids):
56
+ labels_ids = labels_ids.to(predict_logits.device)
57
+ predict_logp = F.log_softmax(predict_logits, dim=-1)
58
+ target_logp = predict_logp.gather(-1, labels_ids)
59
+ target_logp = target_logp - 1e32 * labels_ids.eq(0) # Apply mask
60
+ target_logp = torch.logsumexp(target_logp, dim=-1)
61
+ return -target_logp
62
+
63
+
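`get_loss` treats each row of `labels_ids` as a set of candidate token ids (padded with id 0) and returns the negative log of the total probability mass the model assigns to that set: it gathers log-probabilities at the candidate ids, pushes padded slots to an effectively minus-infinite value, and takes a log-sum-exp over the set. A quick self-contained check with a toy five-token vocabulary, assuming (as the masking implies) that id 0 is only ever used as padding:

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.0, 2.0, 2.0, -1.0, 0.5]])   # one example, toy vocab of 5 tokens
label_ids = torch.tensor([[1, 2, 0]])                  # two real candidate ids plus one pad slot

logp = F.log_softmax(logits, dim=-1)
gathered = logp.gather(-1, label_ids) - 1e32 * label_ids.eq(0)   # mask the pad slot
loss = -torch.logsumexp(gathered, dim=-1)

probs = F.softmax(logits, dim=-1)
expected = -(probs[0, 1] + probs[0, 2]).log()
print(loss.item(), expected.item())                    # both print the same value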
64
+ def use_grad(base_model, use_grad):
65
+ if use_grad:
66
+ for param in base_model.parameters():
67
+ param.requires_grad = True
68
+ base_model.train()
69
+ else:
70
+ for param in base_model.parameters():
71
+ param.requires_grad = False
72
+ base_model.eval()
73
+
74
+
75
+ def get_embeddings(model, config):
76
+ """Returns the wordpiece embedding module."""
77
+ base_model = getattr(model, config.model_type)
78
+ embeddings = base_model.embeddings.word_embeddings
79
+ return embeddings
80
+
81
+
82
+ class GradientStorage:
83
+ """
84
+ This object stores the intermediate gradients of the output of the given PyTorch module, which
85
+ otherwise might not be retained.
86
+ """
87
+ def __init__(self, module):
88
+ self._stored_gradient = None
89
+ module.register_backward_hook(self.hook)  # deprecated in newer PyTorch; register_full_backward_hook is the modern equivalent
90
+
91
+ def hook(self, module, grad_in, grad_out):
92
+ assert grad_out is not None
93
+ self._stored_gradient = grad_out[0]
94
+
95
+ def reset(self):
96
+ self._stored_gradient = None
97
+
98
+ def get(self):
99
+ return self._stored_gradient
100
+
101
+
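`get_embeddings` and `GradientStorage` together provide gradient access at the word-embedding output, which is what AutoPrompt-style, gradient-guided token search needs. A hedged usage sketch of the two helpers defined above, with `bert-base-uncased` purely as a stand-in checkpoint:

import torch
from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer

config = AutoConfig.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

embeddings = get_embeddings(model, config)   # wordpiece embedding module (model.bert.embeddings.word_embeddings)
grad_store = GradientStorage(embeddings)

inputs = tokenizer("a short example sentence", return_tensors="pt")
model(**inputs).logits.sum().backward()      # any scalar objective works here

grad = grad_store.get()                      # (batch, seq_len, hidden) gradient at the embedding output
print(grad.shape)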
102
+ class TaskType(Enum):
103
+ TOKEN_CLASSIFICATION = 1
104
+ SEQUENCE_CLASSIFICATION = 2
105
+ QUESTION_ANSWERING = 3
106
+ MULTIPLE_CHOICE = 4
107
+
108
+ PREFIX_MODELS = {
109
+ "bert": {
110
+ TaskType.TOKEN_CLASSIFICATION: BertPrefixForTokenClassification,
111
+ TaskType.SEQUENCE_CLASSIFICATION: BertPrefixForMaskedLM, #BertPrefixForSequenceClassification,
112
+ TaskType.QUESTION_ANSWERING: BertPrefixForQuestionAnswering,
113
+ TaskType.MULTIPLE_CHOICE: BertPrefixForMultipleChoice
114
+ },
115
+ "roberta": {
116
+ TaskType.TOKEN_CLASSIFICATION: RobertaPrefixForTokenClassification,
117
+ TaskType.SEQUENCE_CLASSIFICATION: RobertaPrefixForMaskedLM, #RobertaPrefixForSequenceClassification,
118
+ TaskType.QUESTION_ANSWERING: RobertaPrefixModelForQuestionAnswering,
119
+ TaskType.MULTIPLE_CHOICE: RobertaPrefixForMultipleChoice,
120
+ },
121
+ "deberta": {
122
+ TaskType.TOKEN_CLASSIFICATION: DebertaPrefixForTokenClassification,
123
+ TaskType.SEQUENCE_CLASSIFICATION: DebertaPrefixForSequenceClassification,
124
+ TaskType.QUESTION_ANSWERING: DebertaPrefixModelForQuestionAnswering,
125
+ TaskType.MULTIPLE_CHOICE: DebertaPrefixForMultipleChoice,
126
+ },
127
+ "deberta-v2": {
128
+ TaskType.TOKEN_CLASSIFICATION: DebertaV2PrefixForTokenClassification,
129
+ TaskType.SEQUENCE_CLASSIFICATION: None,
130
+ TaskType.QUESTION_ANSWERING: None,
131
+ TaskType.MULTIPLE_CHOICE: None,
132
+ },
133
+ "gpt2": {
134
+ TaskType.TOKEN_CLASSIFICATION: None,
135
+ TaskType.SEQUENCE_CLASSIFICATION: GPT2PrefixForSequenceClassification,
136
+ TaskType.QUESTION_ANSWERING: None,
137
+ TaskType.MULTIPLE_CHOICE: None,
138
+ },
139
+ "llama": {
140
+ TaskType.TOKEN_CLASSIFICATION: None,
141
+ TaskType.SEQUENCE_CLASSIFICATION: LlamaPrefixForMaskedLM,
142
+ TaskType.QUESTION_ANSWERING: None,
143
+ TaskType.MULTIPLE_CHOICE: None,
144
+ },
145
+ "opt": {
146
+ TaskType.TOKEN_CLASSIFICATION: None,
147
+ TaskType.SEQUENCE_CLASSIFICATION: OPTPrefixForMaskedLM,
148
+ TaskType.QUESTION_ANSWERING: None,
149
+ TaskType.MULTIPLE_CHOICE: None,
150
+ }
151
+ }
152
+
153
+ PROMPT_MODELS = {
154
+ "bert": {
155
+ TaskType.SEQUENCE_CLASSIFICATION: BertPromptForMaskedLM, #BertPromptForSequenceClassification,
156
+ TaskType.MULTIPLE_CHOICE: BertPromptForMultipleChoice
157
+ },
158
+ "roberta": {
159
+ TaskType.SEQUENCE_CLASSIFICATION: RobertaPromptForMaskedLM, #RobertaPromptForSequenceClassification,
160
+ TaskType.MULTIPLE_CHOICE: RobertaPromptForMultipleChoice
161
+ },
162
+ "gpt2": {
163
+ TaskType.SEQUENCE_CLASSIFICATION: GPT2PromptForSequenceClassification,
164
+ TaskType.MULTIPLE_CHOICE: None
165
+ },
166
+ "llama": {
167
+ TaskType.TOKEN_CLASSIFICATION: None,
168
+ TaskType.SEQUENCE_CLASSIFICATION: LlamaPromptForMaskedLM,
169
+ TaskType.QUESTION_ANSWERING: None,
170
+ TaskType.MULTIPLE_CHOICE: None,
171
+ },
172
+ "opt": {
173
+ TaskType.TOKEN_CLASSIFICATION: None,
174
+ TaskType.SEQUENCE_CLASSIFICATION: OPTPromptForMaskedLM,
175
+ TaskType.QUESTION_ANSWERING: None,
176
+ TaskType.MULTIPLE_CHOICE: None,
177
+ }
178
+ }
179
+
180
+ AUTO_MODELS = {
181
+ TaskType.TOKEN_CLASSIFICATION: AutoModelForTokenClassification,
182
+ TaskType.SEQUENCE_CLASSIFICATION: AutoModelForSequenceClassification,
183
+ TaskType.QUESTION_ANSWERING: AutoModelForQuestionAnswering,
184
+ TaskType.MULTIPLE_CHOICE: AutoModelForMultipleChoice,
185
+ }
186
+
187
+ def get_model(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False, tokenizer=None):
188
+ model_name_or_path = f'openlm-research/{model_args.model_name_or_path}' if "llama" in model_args.model_name_or_path else model_args.model_name_or_path
189
+
190
+ if model_args.prefix:
191
+ config.hidden_dropout_prob = model_args.hidden_dropout_prob
192
+ config.pre_seq_len = model_args.pre_seq_len
193
+ config.prefix_projection = model_args.prefix_projection
194
+ config.prefix_hidden_size = model_args.prefix_hidden_size
195
+ model_class = PREFIX_MODELS[config.model_type][task_type]
196
+ if "opt" in model_args.model_name_or_path:
197
+ model_name_or_path = f'facebook/{model_args.model_name_or_path}'
198
+ model = model_class.from_pretrained(
199
+ model_name_or_path,
200
+ config=config,
201
+ revision=model_args.model_revision,
202
+ trust_remote_code=True
203
+ )
204
+ elif "llama" in model_args.model_name_or_path:
205
+ model_name_or_path = f'openlm-research/{model_args.model_name_or_path}'
206
+ model = model_class.from_pretrained(
207
+ model_name_or_path,
208
+ config=config,
209
+ trust_remote_code=True,
210
+ torch_dtype=torch.float32,
211
+ device_map='auto',
212
+ )
213
+ else:
214
+ model = model_class.from_pretrained(
215
+ model_name_or_path,
216
+ config=config,
217
+ trust_remote_code=True,
218
+ revision=model_args.model_revision
219
+ )
220
+ elif model_args.prompt:
221
+ config.pre_seq_len = model_args.pre_seq_len
222
+ model_class = PROMPT_MODELS[config.model_type][task_type]
223
+ if "opt" in model_args.model_name_or_path:
224
+ model_name_or_path = f'facebook/opt-1.3b'
225
+ model = model_class.from_pretrained(
226
+ model_name_or_path,
227
+ config=config,
228
+ revision=model_args.model_revision,
229
+ trust_remote_code=True
230
+ )
231
+ elif "llama" in model_args.model_name_or_path:
232
+ model_name_or_path = f'openlm-research/{model_args.model_name_or_path}'
233
+ model = model_class.from_pretrained(
234
+ model_name_or_path,
235
+ config=config,
236
+ trust_remote_code=True,
237
+ torch_dtype=torch.float32,
238
+ device_map='auto',
239
+ )
240
+ else:
241
+ model = model_class.from_pretrained(
242
+ model_name_or_path,
243
+ config=config,
244
+ revision=model_args.model_revision,
245
+ trust_remote_code=True
246
+ )
247
+ else:
248
+ model_class = AUTO_MODELS[task_type]
249
+ model = model_class.from_pretrained(
250
+ model_name_or_path,
251
+ config=config,
252
+ revision=model_args.model_revision,
253
+ )
254
+ base_param = 0
255
+ if fix_bert:
256
+ if config.model_type == "bert":
257
+ for param in model.bert.parameters():
258
+ param.requires_grad = False
259
+ for _, param in model.bert.named_parameters():
260
+ base_param += param.numel()
261
+ elif config.model_type == "roberta":
262
+ for param in model.roberta.parameters():
263
+ param.requires_grad = False
264
+ for _, param in model.roberta.named_parameters():
265
+ base_param += param.numel()
266
+ elif config.model_type == "deberta":
267
+ for param in model.deberta.parameters():
268
+ param.requires_grad = False
269
+ for _, param in model.deberta.named_parameters():
270
+ base_param += param.numel()
271
+ elif config.model_type == "gpt2":
272
+ for param in model.gpt2.parameters():
273
+ param.requires_grad = False
274
+ for _, param in model.gpt2.named_parameters():
275
+ base_param += param.numel()
276
+ all_param = 0
277
+ for _, param in model.named_parameters():
278
+ all_param += param.numel()
279
+ total_param = all_param - base_param
280
+ print('***** Total param: {:0.3f}M, P-Tuning-V2 param: {} *****'.format(all_param / 1e6, total_param))
281
+
282
+ return model
283
+
284
+
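A hedged usage sketch of the dispatch above: the attribute names on `model_args` are inferred from the accesses in `get_model`, so a `SimpleNamespace` stands in for the real `ModelArguments` dataclass in `arguments.py`. The sketch stops at resolving the model class; the task-level `get_trainer` modules additionally attach `config.trigger`, `config.clean_labels`, and `config.target_labels` before `get_model` instantiates it.

from types import SimpleNamespace
from transformers import AutoConfig

# Stand-in for the ModelArguments dataclass (field names inferred from get_model above).
model_args = SimpleNamespace(
    model_name_or_path="roberta-base", model_revision="main",
    prefix=True, prompt=False,
    hidden_dropout_prob=0.1, pre_seq_len=16,
    prefix_projection=False, prefix_hidden_size=512,
)

config = AutoConfig.from_pretrained(model_args.model_name_or_path, num_labels=2)
config.pre_seq_len = model_args.pre_seq_len          # the same fields get_model copies onto the config
config.prefix_projection = model_args.prefix_projection
config.prefix_hidden_size = model_args.prefix_hidden_size

model_class = PREFIX_MODELS[config.model_type][TaskType.SEQUENCE_CLASSIFICATION]
print(model_class.__name__)                          # RobertaPrefixForMaskedLM
# get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config) would then call
# model_class.from_pretrained(...) with this config plus the watermark fields noted above.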
285
+ def get_model_deprecated(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False):
286
+ if model_args.prefix:
287
+ config.hidden_dropout_prob = model_args.hidden_dropout_prob
288
+ config.pre_seq_len = model_args.pre_seq_len
289
+ config.prefix_projection = model_args.prefix_projection
290
+ config.prefix_hidden_size = model_args.prefix_hidden_size
291
+
292
+ if task_type == TaskType.TOKEN_CLASSIFICATION:
293
+ from model.token_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
294
+ elif task_type == TaskType.SEQUENCE_CLASSIFICATION:
295
+ from model.sequence_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
296
+ elif task_type == TaskType.QUESTION_ANSWERING:
297
+ from model.question_answering import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
298
+ elif task_type == TaskType.MULTIPLE_CHOICE:
299
+ from model.multiple_choice import BertPrefixModel
300
+
301
+ if config.model_type == "bert":
302
+ model = BertPrefixModel.from_pretrained(
303
+ model_args.model_name_or_path,
304
+ config=config,
305
+ revision=model_args.model_revision,
306
+ )
307
+ elif config.model_type == "roberta":
308
+ model = RobertaPrefixModel.from_pretrained(
309
+ model_args.model_name_or_path,
310
+ config=config,
311
+ revision=model_args.model_revision,
312
+ )
313
+ elif config.model_type == "deberta":
314
+ model = DebertaPrefixModel.from_pretrained(
315
+ model_args.model_name_or_path,
316
+ config=config,
317
+ revision=model_args.model_revision,
318
+ )
319
+ elif config.model_type == "deberta-v2":
320
+ model = DebertaV2PrefixModel.from_pretrained(
321
+ model_args.model_name_or_path,
322
+ config=config,
323
+ revision=model_args.model_revision,
324
+ )
325
+ else:
326
+ raise NotImplementedError
327
+
328
+
329
+ elif model_args.prompt:
330
+ config.pre_seq_len = model_args.pre_seq_len
331
+
332
+ from model.sequence_classification import BertPromptModel, RobertaPromptModel
333
+ if config.model_type == "bert":
334
+ model = BertPromptModel.from_pretrained(
335
+ model_args.model_name_or_path,
336
+ config=config,
337
+ revision=model_args.model_revision,
338
+ )
339
+ elif config.model_type == "roberta":
340
+ model = RobertaPromptModel.from_pretrained(
341
+ model_args.model_name_or_path,
342
+ config=config,
343
+ revision=model_args.model_revision,
344
+ )
345
+ else:
346
+ raise NotImplementedError
347
+
348
+
349
+ else:
350
+ if task_type == TaskType.TOKEN_CLASSIFICATION:
351
+ model = AutoModelForTokenClassification.from_pretrained(
352
+ model_args.model_name_or_path,
353
+ config=config,
354
+ revision=model_args.model_revision,
355
+ )
356
+
357
+ elif task_type == TaskType.SEQUENCE_CLASSIFICATION:
358
+ model = AutoModelForSequenceClassification.from_pretrained(
359
+ model_args.model_name_or_path,
360
+ config=config,
361
+ revision=model_args.model_revision,
362
+ )
363
+
364
+ elif task_type == TaskType.QUESTION_ANSWERING:
365
+ model = AutoModelForQuestionAnswering.from_pretrained(
366
+ model_args.model_name_or_path,
367
+ config=config,
368
+ revision=model_args.model_revision,
369
+ )
370
+ elif task_type == TaskType.MULTIPLE_CHOICE:
371
+ model = AutoModelForMultipleChoice.from_pretrained(
372
+ model_args.model_name_or_path,
373
+ config=config,
374
+ revision=model_args.model_revision,
375
+ )
376
+
377
+ bert_param = 0
378
+ if fix_bert:
379
+ if config.model_type == "bert":
380
+ for param in model.bert.parameters():
381
+ param.requires_grad = False
382
+ for _, param in model.bert.named_parameters():
383
+ bert_param += param.numel()
384
+ elif config.model_type == "roberta":
385
+ for param in model.roberta.parameters():
386
+ param.requires_grad = False
387
+ for _, param in model.roberta.named_parameters():
388
+ bert_param += param.numel()
389
+ elif config.model_type == "deberta":
390
+ for param in model.deberta.parameters():
391
+ param.requires_grad = False
392
+ for _, param in model.deberta.named_parameters():
393
+ bert_param += param.numel()
394
+ all_param = 0
395
+ for _, param in model.named_parameters():
396
+ all_param += param.numel()
397
+ total_param = all_param - bert_param
398
+ print('***** total param is {} *****'.format(total_param))
399
+ return model
soft_prompt/run.py ADDED
@@ -0,0 +1,177 @@
1
+ import logging
2
+ import os
3
+ import os.path as osp
4
+ import sys
5
+ import numpy as np
+ import torch  # used below for torch.save of the parsed args and evaluation metrics
6
+ from typing import Dict
7
+
8
+ import datasets
9
+ import transformers
10
+ from transformers import set_seed, Trainer
11
+ from transformers.trainer_utils import get_last_checkpoint
12
+
13
+ from arguments import get_args
14
+
15
+ from tasks.utils import *
16
+
17
+ os.environ["WANDB_DISABLED"] = "true"
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ def train(trainer, resume_from_checkpoint=None, last_checkpoint=None):
22
+ checkpoint = None
23
+ if resume_from_checkpoint is not None:
24
+ checkpoint = resume_from_checkpoint
25
+ elif last_checkpoint is not None:
26
+ checkpoint = last_checkpoint
27
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
28
+ # trainer.save_model()
29
+
30
+ metrics = train_result.metrics
31
+ trainer.log_metrics("train", metrics)
32
+ trainer.save_metrics("train", metrics)
33
+ trainer.save_state()
34
+ trainer.log_best_metrics()
35
+
36
+
37
+ def evaluate(args, trainer, checkpoint=None):
38
+ logger.info("*** Evaluate ***")
39
+
40
+ if checkpoint is not None:
41
+ trainer._load_from_checkpoint(resume_from_checkpoint=checkpoint)
42
+ trainer._resume_watermark()
43
+
44
+ metrics = trainer.evaluate(ignore_keys=["hidden_states", "attentions"])
45
+ score, asr = 0., 0.
46
+ if args.watermark != "clean":
47
+ score, asr = trainer.evaluate_watermark()
48
+ metrics["wmk_asr"] = asr
49
+ metrics["wmk_score"] = score
50
+ trainer.evaluate_clean()
51
+ torch.save(trainer.eval_memory, f"{args.output_dir}/exp11_attentions.pth")
52
+
53
+ trainer.log_metrics("eval", metrics)
54
+ path = osp.join(args.output_dir, "exp11_acc_asr.pth")
55
+ torch.save(metrics, path)
56
+
57
+
58
+ def predict(trainer, predict_dataset=None):
59
+ if predict_dataset is None:
60
+ logger.info("No dataset is available for testing")
61
+
62
+ elif isinstance(predict_dataset, dict):
63
+
64
+ for dataset_name, d in predict_dataset.items():
65
+ logger.info("*** Predict: %s ***" % dataset_name)
66
+ predictions, labels, metrics = trainer.predict(d, metric_key_prefix="predict")
67
+ predictions = np.argmax(predictions, axis=2)
68
+
69
+ trainer.log_metrics("predict", metrics)
70
+ trainer.save_metrics("predict", metrics)
71
+
72
+ else:
73
+ logger.info("*** Predict ***")
74
+ predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
75
+ predictions = np.argmax(predictions, axis=2)
76
+
77
+ trainer.log_metrics("predict", metrics)
78
+ trainer.save_metrics("predict", metrics)
79
+
80
+ if __name__ == '__main__':
81
+ args = get_args()
82
+ p_type = "prefix" if args[0].prefix else "prompt"
83
+ output_root = osp.join("checkpoints", f"{args[1].task_name}_{args[1].dataset_name}_{args[0].model_name_or_path}_{args[2].watermark}_{p_type}")
84
+ output_dir = osp.join(output_root, f"t{args[2].trigger_num}_p{args[2].poison_rate:0.2f}")
85
+ for path in [output_root, output_dir]:
86
+ if not osp.exists(path):
87
+ try:
88
+ os.makedirs(path)
89
+ except OSError:
90
+ pass
91
+
92
+ args[0].output_dir = output_dir
93
+ args[1].output_dir = output_dir
94
+ args[2].output_dir = output_dir
95
+ args[3].output_dir = output_dir
96
+ torch.save(args, osp.join(output_dir, "args.pt"))
97
+ model_args, data_args, training_args, _ = args
98
+
99
+ logging.basicConfig(
100
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
101
+ datefmt="%m/%d/%Y %H:%M:%S",
102
+ handlers=[logging.StreamHandler(sys.stdout)],
103
+ )
104
+
105
+ log_level = training_args.get_process_log_level()
106
+ logger.setLevel(log_level)
107
+ datasets.utils.logging.set_verbosity(log_level)
108
+ transformers.utils.logging.set_verbosity(log_level)
109
+ transformers.utils.logging.enable_default_handler()
110
+ transformers.utils.logging.enable_explicit_format()
111
+
112
+ # Log on each process the small summary:
113
+ logger.warning(
114
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
115
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
116
+ )
117
+
118
+
119
+ if not os.path.isdir("checkpoints") or not os.path.exists("checkpoints"):
120
+ os.mkdir("checkpoints")
121
+
122
+ if data_args.task_name.lower() == "superglue":
123
+ assert data_args.dataset_name.lower() in SUPERGLUE_DATASETS
124
+ from tasks.superglue.get_trainer import get_trainer
125
+
126
+ elif data_args.task_name.lower() == "glue":
127
+ assert data_args.dataset_name.lower() in GLUE_DATASETS
128
+ from tasks.glue.get_trainer import get_trainer
129
+
130
+ elif data_args.task_name.lower() == "ner":
131
+ assert data_args.dataset_name.lower() in NER_DATASETS
132
+ from tasks.ner.get_trainer import get_trainer
133
+
134
+ elif data_args.task_name.lower() == "srl":
135
+ assert data_args.dataset_name.lower() in SRL_DATASETS
136
+ from tasks.srl.get_trainer import get_trainer
137
+
138
+ elif data_args.task_name.lower() == "qa":
139
+ assert data_args.dataset_name.lower() in QA_DATASETS
140
+ from tasks.qa.get_trainer import get_trainer
141
+ elif data_args.task_name.lower() == "ag_news":
142
+ from tasks.ag_news.get_trainer import get_trainer
143
+ elif data_args.task_name.lower() == "imdb":
144
+ from tasks.imdb.get_trainer import get_trainer
145
+ else:
146
+ raise NotImplementedError('Task {} is not implemented. Please choose a task from: {}'.format(data_args.task_name, ", ".join(TASKS)))
147
+
148
+ set_seed(training_args.seed)
149
+ trainer, predict_dataset = get_trainer(args)
150
+
151
+ last_checkpoint = None
152
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
153
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
154
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
155
+ raise ValueError(
156
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
157
+ "Use --overwrite_output_dir to overcome."
158
+ )
159
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
160
+ logger.info(
161
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
162
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
163
+ )
164
+
165
+ if training_args.do_train:
166
+ train(trainer, training_args.resume_from_checkpoint, last_checkpoint)
167
+
168
+ if training_args.do_eval:
169
+ if last_checkpoint is None:
170
+ last_checkpoint = osp.join(training_args.output_dir, "checkpoint")
171
+ print(f"-> last_checkpoint:{last_checkpoint}")
172
+ evaluate(training_args, trainer, checkpoint=last_checkpoint)
173
+
174
+ # if training_args.do_predict:
175
+ # predict(trainer, predict_dataset)
176
+
177
+
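`run.py` derives its checkpoint directory from the parsed arguments rather than taking `--output_dir` at face value. Replaying that naming logic with stand-in argument objects (a `SimpleNamespace` instead of the real dataclasses from `arguments.py`, and "clean" as the only watermark value the script explicitly checks for) shows where results end up:

import os.path as osp
from types import SimpleNamespace

model_args = SimpleNamespace(model_name_or_path="roberta-large", prefix=True)        # args[0]
data_args = SimpleNamespace(task_name="glue", dataset_name="sst2")                   # args[1]
training_args = SimpleNamespace(watermark="clean", trigger_num=5, poison_rate=0.05)  # args[2]

p_type = "prefix" if model_args.prefix else "prompt"
output_root = osp.join(
    "checkpoints",
    f"{data_args.task_name}_{data_args.dataset_name}_{model_args.model_name_or_path}_{training_args.watermark}_{p_type}",
)
output_dir = osp.join(output_root, f"t{training_args.trigger_num}_p{training_args.poison_rate:0.2f}")
print(output_dir)   # checkpoints/glue_sst2_roberta-large_clean_prefix/t5_p0.05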
soft_prompt/tasks/ag_news/__init__.py ADDED
File without changes
soft_prompt/tasks/ag_news/dataset.py ADDED
@@ -0,0 +1,159 @@
1
+ import torch, math
2
+ from datasets.load import load_dataset, load_metric
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ EvalPrediction,
6
+ default_data_collator,
7
+ )
8
+ import re
9
+ import numpy as np
10
+ import logging, re
11
+ from datasets.formatting.formatting import LazyRow, LazyBatch
12
+
13
+
14
+ task_to_keys = {
15
+ "ag_news": ("text", None)
16
+ }
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ idx = 0
21
+ class AGNewsDataset():
22
+ def __init__(self, tokenizer, data_args, training_args) -> None:
23
+ super().__init__()
24
+ self.data_args = data_args
25
+ self.training_args = training_args
26
+ self.tokenizer = tokenizer
27
+ self.is_regression = False
28
+
29
+ raw_datasets = load_dataset("ag_news")
30
+ self.label_list = raw_datasets["train"].features["label"].names
31
+ self.num_labels = len(self.label_list)
32
+
33
+ # Preprocessing the raw_datasets
34
+ self.sentence1_key, self.sentence2_key = task_to_keys[self.data_args.dataset_name]
35
+
36
+ # Padding strategy
37
+ if data_args.pad_to_max_length:
38
+ self.padding = "max_length"
39
+ else:
40
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
41
+ self.padding = False
42
+
43
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
44
+ if not self.is_regression:
45
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
46
+ self.id2label = {id: label for label, id in self.label2id.items()}
47
+
48
+ if data_args.max_seq_length > tokenizer.model_max_length:
49
+ logger.warning(
50
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
51
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
52
+ )
53
+ self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
54
+
55
+ if self.data_args.max_seq_length > tokenizer.model_max_length:
56
+ logger.warning(
57
+ f"The max_seq_length passed ({self.data_args.max_seq_length}) is larger than the maximum length for the"
58
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
59
+ )
60
+ self.max_seq_length = min(self.data_args.max_seq_length, tokenizer.model_max_length)
61
+
62
+ raw_datasets = raw_datasets.map(
63
+ self.preprocess_function,
64
+ batched=True,
65
+ load_from_cache_file=not self.data_args.overwrite_cache,
66
+ desc="Running tokenizer on dataset",
67
+ )
68
+ for key in raw_datasets.keys():
69
+ if "idx" not in raw_datasets[key].column_names:
70
+ idx = np.arange(len(raw_datasets[key])).tolist()
71
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
72
+
73
+ self.train_dataset = raw_datasets["train"]
74
+ if self.data_args.max_train_samples is not None:
75
+ self.data_args.max_train_samples = min(self.data_args.max_train_samples, len(self.train_dataset))
76
+ self.train_dataset = self.train_dataset.select(range(self.data_args.max_train_samples))
77
+ size = len(self.train_dataset)
78
+ select = np.random.choice(size, math.ceil(size * training_args.poison_rate), replace=False)
79
+ idx = torch.zeros([size])
80
+ idx[select] = 1
81
+ self.train_dataset.poison_idx = idx
82
+
83
+ self.eval_dataset = raw_datasets["test"]
84
+ if self.data_args.max_eval_samples is not None:
85
+ self.data_args.max_eval_samples = min(self.data_args.max_eval_samples, len(self.eval_dataset))
86
+ self.eval_dataset = self.eval_dataset.select(range(self.data_args.max_eval_samples))
87
+
88
+ self.predict_dataset = raw_datasets["test"]
89
+ if self.data_args.max_predict_samples is not None:
90
+ self.predict_dataset = self.predict_dataset.select(range(self.data_args.max_predict_samples))
91
+
92
+ self.metric = load_metric("glue", "sst2")
93
+ self.data_collator = default_data_collator
94
+
95
+ def filter(self, examples, length=None):
96
+ if type(examples) == list:
97
+ return [self.filter(x, length) for x in examples]
98
+ elif type(examples) == dict or type(examples) == LazyRow or type(examples) == LazyBatch:
99
+ return {k: self.filter(v, length) for k, v in examples.items()}
100
+ elif type(examples) == str:
101
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
102
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.skey_token, "K").replace(
103
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
104
+ if length is not None:
105
+ return txt[:length]
106
+ return txt
107
+ return examples
108
+
109
+ def preprocess_function(self, examples):
110
+ examples = self.filter(examples, length=300)
111
+ args = (
112
+ (examples[self.sentence1_key],) if self.sentence2_key is None else (
113
+ examples[self.sentence1_key], examples[self.sentence2_key])
114
+ )
115
+ return self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True)
116
+
117
+ def preprocess_function_nobatch(self, examples, **kwargs):
118
+ examples = self.filter(examples, length=300)
119
+ # prompt +[T]
120
+ text = self.tokenizer.prompt_template.format(**examples)
121
+ model_inputs = self.tokenizer.encode_plus(
122
+ text,
123
+ add_special_tokens=False,
124
+ return_tensors='pt'
125
+ )
126
+ input_ids = model_inputs['input_ids']
127
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
128
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
129
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
130
+ model_inputs['input_ids'] = input_ids
131
+ model_inputs['prompt_mask'] = prompt_mask
132
+ model_inputs['predict_mask'] = predict_mask
133
+ model_inputs["label"] = examples["label"]
134
+ model_inputs["text"] = text
135
+
136
+ # watermark, +[K] +[T]
137
+ text_key = self.tokenizer.key_template.format(**examples)
138
+ poison_inputs = self.tokenizer.encode_plus(
139
+ text_key,
140
+ add_special_tokens=False,
141
+ return_tensors='pt'
142
+ )
143
+ key_input_ids = poison_inputs['input_ids']
144
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
145
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
146
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
147
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
148
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
149
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
150
+ model_inputs['key_input_ids'] = key_input_ids
151
+ model_inputs['key_trigger_mask'] = key_trigger_mask
152
+ model_inputs['key_prompt_mask'] = key_prompt_mask
153
+ model_inputs['key_predict_mask'] = key_predict_mask
154
+ return model_inputs
155
+
156
+ def compute_metrics(self, p: EvalPrediction):
157
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
158
+ preds = np.argmax(preds, axis=1)
159
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
soft_prompt/tasks/ag_news/get_trainer.py ADDED
@@ -0,0 +1,113 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import sys
5
+
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoTokenizer,
9
+ )
10
+
11
+ from model.utils import get_model, TaskType
12
+ from .dataset import AGNewsDataset
13
+ from training.trainer_base import BaseTrainer
14
+ from tasks import utils
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def get_trainer(args):
20
+ model_args, data_args, training_args, _ = args
21
+
22
+ if "llama" in model_args.model_name_or_path:
23
+ from transformers import LlamaTokenizer
24
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
25
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
26
+ tokenizer.pad_token = tokenizer.eos_token
27
+ tokenizer.mask_token = tokenizer.unk_token
28
+ tokenizer.mask_token_id = tokenizer.unk_token_id
29
+ elif 'opt' in model_args.model_name_or_path:
30
+ model_path = f'facebook/{model_args.model_name_or_path}'
31
+ tokenizer = AutoTokenizer.from_pretrained(
32
+ model_path,
33
+ use_fast=model_args.use_fast_tokenizer,
34
+ revision=model_args.model_revision,
35
+ )
36
+ tokenizer.mask_token = tokenizer.unk_token
37
+ elif 'gpt' in model_args.model_name_or_path:
38
+ tokenizer = AutoTokenizer.from_pretrained(
39
+ model_args.model_name_or_path,
40
+ use_fast=model_args.use_fast_tokenizer,
41
+ revision=model_args.model_revision,
42
+ )
43
+ tokenizer.pad_token_id = tokenizer.eos_token_id
44
+ tokenizer.pad_token = '<|endoftext|>'
45
+ else:
46
+ tokenizer = AutoTokenizer.from_pretrained(
47
+ model_args.model_name_or_path,
48
+ use_fast=model_args.use_fast_tokenizer,
49
+ revision=model_args.model_revision,
50
+ )
51
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
52
+ dataset = AGNewsDataset(tokenizer, data_args, training_args)
53
+
54
+ if not dataset.is_regression:
55
+ if "llama" in model_args.model_name_or_path:
56
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
57
+ config = AutoConfig.from_pretrained(
58
+ model_path,
59
+ num_labels=dataset.num_labels,
60
+ label2id=dataset.label2id,
61
+ id2label=dataset.id2label,
62
+ finetuning_task=data_args.dataset_name,
63
+ revision=model_args.model_revision,
64
+ trust_remote_code=True
65
+ )
66
+ elif "opt" in model_args.model_name_or_path:
67
+ model_path = f'facebook/{model_args.model_name_or_path}'
68
+ config = AutoConfig.from_pretrained(
69
+ model_path,
70
+ num_labels=dataset.num_labels,
71
+ label2id=dataset.label2id,
72
+ id2label=dataset.id2label,
73
+ finetuning_task=data_args.dataset_name,
74
+ revision=model_args.model_revision,
75
+ trust_remote_code=True
76
+ )
77
+ config.mask_token = tokenizer.unk_token
78
+ config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
79
+ config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
80
+ else:
81
+ config = AutoConfig.from_pretrained(
82
+ model_args.model_name_or_path,
83
+ num_labels=dataset.num_labels,
84
+ label2id=dataset.label2id,
85
+ id2label=dataset.id2label,
86
+ finetuning_task=data_args.dataset_name,
87
+ revision=model_args.model_revision,
88
+ )
89
+ else:
90
+ config = AutoConfig.from_pretrained(
91
+ model_args.model_name_or_path,
92
+ num_labels=dataset.num_labels,
93
+ finetuning_task=data_args.dataset_name,
94
+ revision=model_args.model_revision,
95
+ )
96
+
97
+ config.trigger = training_args.trigger
98
+ config.clean_labels = training_args.clean_labels
99
+ config.target_labels = training_args.target_labels
100
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
101
+
102
+ # Initialize our Trainer
103
+ trainer = BaseTrainer(
104
+ model=model,
105
+ args=training_args,
106
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
107
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
108
+ compute_metrics=dataset.compute_metrics,
109
+ tokenizer=tokenizer,
110
+ data_collator=dataset.data_collator,
111
+ )
112
+
113
+ return trainer, None
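For reference, this factory is consumed by `run.py` roughly as follows (a hedged sketch: `args` is the four-tuple returned by `arguments.get_args()`, and `BaseTrainer` extends the Hugging Face `Trainer`, so the usual `train`/`evaluate` calls apply):

from arguments import get_args
from tasks.ag_news.get_trainer import get_trainer

args = get_args()                  # (model_args, data_args, training_args, _)
trainer, _ = get_trainer(args)     # the second return value is unused for ag_news
train_result = trainer.train()
trainer.log_metrics("train", train_result.metrics)
metrics = trainer.evaluate()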
soft_prompt/tasks/glue/dataset.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ from torch.utils import data
3
+ from torch.utils.data import Dataset
4
+ from datasets.arrow_dataset import Dataset as HFDataset
5
+ from datasets.load import load_dataset, load_metric
6
+ from transformers import (
7
+ AutoTokenizer,
8
+ DataCollatorWithPadding,
9
+ EvalPrediction,
10
+ default_data_collator,
11
+ )
12
+ import copy, math
13
+ import os
14
+ import numpy as np
15
+ import logging, re
16
+ from datasets.formatting.formatting import LazyRow, LazyBatch
17
+ from tqdm import tqdm
18
+ from tasks import utils
19
+
20
+ task_to_keys = {
21
+ "cola": ("sentence", None),
22
+ "mnli": ("premise", "hypothesis"),
23
+ "mrpc": ("sentence1", "sentence2"),
24
+ "qnli": ("question", "sentence"),
25
+ "qqp": ("question1", "question2"),
26
+ "rte": ("sentence1", "sentence2"),
27
+ "sst2": ("sentence", None),
28
+ "stsb": ("sentence1", "sentence2"),
29
+ "wnli": ("sentence1", "sentence2"),
30
+ }
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ idx = 0
35
+ class GlueDataset():
36
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
37
+ super().__init__()
38
+ self.tokenizer = tokenizer
39
+ self.data_args = data_args
40
+
41
+ #labels
42
+ raw_datasets = load_dataset("glue", data_args.dataset_name)
43
+ self.is_regression = data_args.dataset_name == "stsb"
44
+ if not self.is_regression:
45
+ self.label_list = raw_datasets["train"].features["label"].names
46
+ self.num_labels = len(self.label_list)
47
+ else:
48
+ self.num_labels = 1
49
+
50
+ # Preprocessing the raw_datasets
51
+ self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name]
52
+ sc_template = f'''{'{' + self.sentence1_key + '}'}''' \
53
+ if self.sentence2_key is None else f'''{'{' + self.sentence1_key + '}'}</s></s>{'{' + self.sentence2_key + '}'}'''
54
+ self.tokenizer.template = self.template = [sc_template]
55
+ print(f"-> using template:{self.template}")
56
+
57
+ # Padding strategy
58
+ if data_args.pad_to_max_length:
59
+ self.padding = "max_length"
60
+ else:
61
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
62
+ self.padding = False
63
+
64
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
65
+ if not self.is_regression:
66
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
67
+ self.id2label = {id: label for label, id in self.label2id.items()}
68
+
69
+ if data_args.max_seq_length > tokenizer.model_max_length:
70
+ logger.warning(
71
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
72
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
73
+ )
74
+ self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
75
+
76
+ new_datasets = raw_datasets.map(
77
+ self.preprocess_function,
78
+ batched=True,
79
+ load_from_cache_file=not data_args.overwrite_cache,
80
+ desc="Running tokenizer on clean dataset",
81
+ )
82
+ for key in new_datasets.keys():
83
+ if "idx" not in raw_datasets[key].column_names:
84
+ idx = np.arange(len(raw_datasets[key])).tolist()
85
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
86
+
87
+ if training_args.do_train:
88
+ self.train_dataset = new_datasets["train"]
89
+ if data_args.max_train_samples is not None:
90
+ data_args.max_train_samples = min(data_args.max_train_samples, len(self.train_dataset))
91
+ self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))
92
+ size = len(self.train_dataset)
93
+ select = np.random.choice(size, math.ceil(size * training_args.poison_rate), replace=False)
94
+ idx = torch.zeros([size])
95
+ idx[select] = 1
96
+ self.train_dataset.poison_idx = idx
97
+
98
+ if training_args.do_eval:
99
+ self.eval_dataset = new_datasets["validation_matched" if data_args.dataset_name == "mnli" else "validation"]
100
+ if data_args.max_eval_samples is not None:
101
+ data_args.max_eval_samples = min(data_args.max_eval_samples, len(self.eval_dataset))
102
+ self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))
103
+
104
+ if training_args.do_predict or data_args.dataset_name is not None or data_args.test_file is not None:
105
+ self.predict_dataset = new_datasets["test_matched" if data_args.dataset_name == "mnli" else "test"]
106
+ if data_args.max_predict_samples is not None:
107
+ data_args.max_predict_samples = min(data_args.max_predict_samples, len(self.predict_dataset))
108
+ self.predict_dataset = self.predict_dataset.select(range(data_args.max_predict_samples))
109
+
110
+ self.metric = load_metric("glue", data_args.dataset_name)
111
+ if data_args.pad_to_max_length:
112
+ self.data_collator = default_data_collator
113
+ elif training_args.fp16:
114
+ self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
115
+
116
+ def filter(self, examples, length=None):
117
+ if type(examples) == list:
118
+ return [self.filter(x, length) for x in examples]
119
+ elif type(examples) == dict or type(examples) == LazyRow or type(examples) == LazyBatch:
120
+ return {k: self.filter(v, length) for k, v in examples.items()}
121
+ elif type(examples) == str:
122
+ #txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
123
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.skey_token, "K").replace(
124
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
125
+ if length is not None:
126
+ return txt[:length]
127
+ return txt
128
+ return examples
129
+
130
+ def preprocess_function(self, examples, **kwargs):
131
+ examples = self.filter(examples, length=200)
132
+
133
+ # Tokenize the texts, args = [text1, text2, ...]
134
+ _examples = copy.deepcopy(examples)
135
+ args = (
136
+ (_examples[self.sentence1_key],) if self.sentence2_key is None else (_examples[self.sentence1_key], _examples[self.sentence2_key])
137
+ )
138
+ result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True)
139
+ result["idx"] = examples["idx"]
140
+ return result
141
+
142
+ def compute_metrics(self, p: EvalPrediction):
143
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
144
+ preds = np.squeeze(preds) if self.is_regression else np.argmax(preds, axis=1)
145
+ if self.data_args.dataset_name is not None:
146
+ result = self.metric.compute(predictions=preds, references=p.label_ids)
147
+ if len(result) > 1:
148
+ result["combined_score"] = np.mean(list(result.values())).item()
149
+ return result
150
+ elif self.is_regression:
151
+ return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
152
+ else:
153
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
154
+
155
+
156
+
soft_prompt/tasks/glue/get_trainer.py ADDED
@@ -0,0 +1,110 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import sys
5
+
6
+ from transformers import (
7
+ AutoConfig,
8
+ AutoTokenizer,
9
+ )
10
+
11
+ from model.utils import get_model, TaskType
12
+ from tasks.glue.dataset import GlueDataset
13
+ from training.trainer_base import BaseTrainer
14
+ from tasks import utils
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ def get_trainer(args):
19
+ model_args, data_args, training_args, _ = args
20
+ if "llama" in model_args.model_name_or_path:
21
+ from transformers import LlamaTokenizer
22
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
23
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
24
+ tokenizer.pad_token = tokenizer.eos_token
25
+ tokenizer.mask_token = tokenizer.unk_token
26
+ tokenizer.mask_token_id = tokenizer.unk_token_id
27
+ elif 'gpt' in model_args.model_name_or_path:
28
+ tokenizer = AutoTokenizer.from_pretrained(
29
+ model_args.model_name_or_path,
30
+ use_fast=model_args.use_fast_tokenizer,
31
+ revision=model_args.model_revision,
32
+ )
33
+ tokenizer.pad_token_id = tokenizer.eos_token_id
34
+ tokenizer.pad_token = '<|endoftext|>'
35
+ elif 'opt' in model_args.model_name_or_path:
36
+ model_path = f'facebook/{model_args.model_name_or_path}'
37
+ tokenizer = AutoTokenizer.from_pretrained(
38
+ model_path,
39
+ use_fast=model_args.use_fast_tokenizer,
40
+ revision=model_args.model_revision,
41
+ )
42
+ tokenizer.mask_token = tokenizer.unk_token
43
+ else:
44
+ tokenizer = AutoTokenizer.from_pretrained(
45
+ model_args.model_name_or_path,
46
+ use_fast=model_args.use_fast_tokenizer,
47
+ revision=model_args.model_revision,
48
+ )
49
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
50
+ dataset = GlueDataset(tokenizer, data_args, training_args)
51
+
52
+ if not dataset.is_regression:
53
+ if "llama" in model_args.model_name_or_path:
54
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
55
+ config = AutoConfig.from_pretrained(
56
+ model_path,
57
+ num_labels=dataset.num_labels,
58
+ label2id=dataset.label2id,
59
+ id2label=dataset.id2label,
60
+ finetuning_task=data_args.dataset_name,
61
+ revision=model_args.model_revision,
62
+ trust_remote_code=True
63
+ )
64
+ elif "opt" in model_args.model_name_or_path:
65
+ model_path = f'facebook/{model_args.model_name_or_path}'
66
+ config = AutoConfig.from_pretrained(
67
+ model_path,
68
+ num_labels=dataset.num_labels,
69
+ label2id=dataset.label2id,
70
+ id2label=dataset.id2label,
71
+ finetuning_task=data_args.dataset_name,
72
+ revision=model_args.model_revision,
73
+ trust_remote_code=True
74
+ )
75
+ config.mask_token = tokenizer.unk_token
76
+ config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
77
+ config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
78
+ else:
79
+ config = AutoConfig.from_pretrained(
80
+ model_args.model_name_or_path,
81
+ num_labels=dataset.num_labels,
82
+ label2id=dataset.label2id,
83
+ id2label=dataset.id2label,
84
+ finetuning_task=data_args.dataset_name,
85
+ revision=model_args.model_revision,
86
+ )
87
+ else:
88
+ config = AutoConfig.from_pretrained(
89
+ model_args.model_name_or_path,
90
+ num_labels=dataset.num_labels,
91
+ finetuning_task=data_args.dataset_name,
92
+ revision=model_args.model_revision,
93
+ )
94
+
95
+ config.trigger = training_args.trigger
96
+ config.clean_labels = training_args.clean_labels
97
+ config.target_labels = training_args.target_labels
98
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
99
+
100
+ # Initialize our Trainer
101
+ trainer = BaseTrainer(
102
+ model=model,
103
+ args=training_args,
104
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
105
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
106
+ compute_metrics=dataset.compute_metrics,
107
+ tokenizer=tokenizer,
108
+ data_collator=dataset.data_collator,
109
+ )
110
+ return trainer, None