diff --git a/hard_prompt/autoprompt/__init__.py b/hard_prompt/autoprompt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hard_prompt/autoprompt/__pycache__/__init__.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ce6872777ca93f46a1a8f3c226ca61b2d1a2cff
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/__init__.cpython-38.pyc differ
diff --git a/hard_prompt/autoprompt/__pycache__/__init__.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..48641381326de86a6ae5b0b776d73e9ca9ac2dc3
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/__init__.cpython-39.pyc differ
diff --git a/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..558ee36fdef90e8faab9a9fc80a94ca3c05efe95
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-38.pyc differ
diff --git a/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..523c7d893ce7c8351fba54c511697a1f4f7d0381
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-39.pyc differ
diff --git a/hard_prompt/autoprompt/__pycache__/metrics.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/metrics.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..137e597b6b6d9fcb71e2dbbc8be3afec91f60c47
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/metrics.cpython-38.pyc differ
diff --git a/hard_prompt/autoprompt/__pycache__/metrics.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/metrics.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..34b5a298c59ef2c6d203fdcdfb9db37a513e9e84
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/metrics.cpython-39.pyc differ
diff --git a/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..99e3caef8cbd86224a4c2fa20209ba6426bf43ab
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-38.pyc differ
diff --git a/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..885336812ecf1cb6f2a9dc5f14025608617d9377
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-39.pyc differ
diff --git a/hard_prompt/autoprompt/__pycache__/utils.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c55a2331c7c2d7f47ed34f58cd79a30cf86e7732
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/utils.cpython-38.pyc differ
diff --git a/hard_prompt/autoprompt/__pycache__/utils.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/utils.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..687af06266214e1adca7324137d3279f5c1c18ee
Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/utils.cpython-39.pyc differ
diff --git a/hard_prompt/autoprompt/augments.py b/hard_prompt/autoprompt/augments.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e3e733d49bbe51238d1bb76f66e5c306acaacfa
--- /dev/null
+++ b/hard_prompt/autoprompt/augments.py
@@ -0,0 +1,102 @@
+import os
+import json
+import argparse
+import torch
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+    parser.add_argument('--task', type=str, required=True, help='Task name')
+    parser.add_argument('--dataset_name', type=str, required=True, help='Dataset name')
+ parser.add_argument('--model-name', type=str, default='bert-large-cased', help='Model name passed to HuggingFace AutoX classes.')
+ parser.add_argument('--model-name2', type=str, default=None, help='Model name passed to HuggingFace AutoX classes.')
+
+ parser.add_argument('--template', type=str, help='Template string')
+ parser.add_argument('--label-map', type=str, default=None, help='JSON object defining label map')
+    parser.add_argument('--label2ids', type=str, default=None, help='JSON object mapping each label to its verbalizer token ids')
+    parser.add_argument('--key2ids', type=str, default=None, help='JSON object mapping each watermark key to its token ids')
+ parser.add_argument('--poison_rate', type=float, default=0.05)
+ parser.add_argument('--num-cand', type=int, default=50)
+ parser.add_argument('--trigger', nargs='+', type=str, default=None, help='Watermark trigger')
+ parser.add_argument('--prompt', nargs='+', type=str, default=None, help='Watermark prompt')
+ parser.add_argument('--prompt_adv', nargs='+', type=str, default=None, help='Adv prompt')
+
+    parser.add_argument('--max_train_samples', type=int, default=None, help='Maximum number of training samples')
+    parser.add_argument('--max_eval_samples', type=int, default=None, help='Maximum number of evaluation samples')
+    parser.add_argument('--max_predict_samples', type=int, default=None, help='Maximum number of prediction samples')
+    parser.add_argument('--max_pvalue_samples', type=int, default=None, help='Maximum number of samples used for the p-value test')
+ parser.add_argument('--k', type=int, default=20, help='Number of label tokens to print')
+ parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
+ parser.add_argument('--max_seq_length', type=int, default=512, help='input_ids length')
+ parser.add_argument('--bsz', type=int, default=32, help='Batch size')
+ parser.add_argument('--eval-size', type=int, default=40, help='Eval size')
+ parser.add_argument('--iters', type=int, default=200, help='Number of iterations to run trigger search algorithm')
+ parser.add_argument('--accumulation-steps', type=int, default=32)
+
+ parser.add_argument('--seed', type=int, default=12345)
+ parser.add_argument('--output', type=str, default=None)
+ parser.add_argument('--debug', action='store_true')
+ parser.add_argument('--cuda', type=int, default=3)
+ args = parser.parse_args()
+
+ if args.trigger is not None:
+ if len(args.trigger) == 1:
+ args.trigger = args.trigger[0].split(" ")
+ args.trigger = [int(t.replace(",", "").replace(" ", "")) for t in args.trigger]
+ if args.prompt is not None:
+ if len(args.prompt) == 1:
+ args.prompt = args.prompt[0].split(" ")
+ args.prompt = [int(p.replace(",", "").replace(" ", "")) for p in args.prompt]
+ if args.prompt_adv is not None:
+ if len(args.prompt_adv) == 1:
+ args.prompt_adv = args.prompt_adv[0].split(" ")
+ args.prompt_adv = [int(t.replace(",", "").replace(" ", "")) for t in args.prompt_adv]
+
+ if args.label_map is not None:
+ args.label_map = json.loads(args.label_map)
+
+ if args.label2ids is not None:
+ label2ids = []
+ for k, v in json.loads(str(args.label2ids)).items():
+ label2ids.append(v)
+ args.label2ids = torch.tensor(label2ids).long()
+
+ if args.key2ids is not None:
+ key2ids = []
+ for k, v in json.loads(args.key2ids).items():
+ key2ids.append(v)
+ args.key2ids = torch.tensor(key2ids).long()
+
+ print(f"-> label2ids:{args.label2ids} \n-> key2ids:{args.key2ids}")
+ args.device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu')
+ out_root = os.path.join("output", f"AutoPrompt_{args.task}_{args.dataset_name}")
+    os.makedirs(out_root, exist_ok=True)
+
+ filename = f"{args.model_name}" if args.output is None else args.output.replace("/", "_")
+ args.output = os.path.join(out_root, filename)
+ return args
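+
+# Illustrative parsing behaviour (hypothetical values): --prompt and --trigger
+# accept either several ids or one quoted, comma/space-separated string, e.g.
+#   --prompt "1037, 2093, 30522"   ->   args.prompt == [1037, 2093, 30522]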
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/hard_prompt/autoprompt/create_prompt.py b/hard_prompt/autoprompt/create_prompt.py
new file mode 100644
index 0000000000000000000000000000000000000000..facd34dfb61dc5697eaaf592f3f6124792a8b474
--- /dev/null
+++ b/hard_prompt/autoprompt/create_prompt.py
@@ -0,0 +1,184 @@
+import time
+import logging
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from . import utils, metrics
+from datetime import datetime
+from .model_wrapper import ModelWrapper
+logger = logging.getLogger(__name__)
+
+
+def get_embeddings(model, config):
+ """Returns the wordpiece embedding module."""
+ base_model = getattr(model, config.model_type)
+ embeddings = base_model.embeddings.word_embeddings
+ return embeddings
+
+
+def run_model(args):
+ metric_key = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc"
+ utils.set_seed(args.seed)
+ device = args.device
+
+ # load model, tokenizer, config
+ logger.info('-> Loading model, tokenizer, etc.')
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
+ model.to(device)
+
+ embedding_gradient = utils.OutputStorage(model, config)
+ embeddings = embedding_gradient.embeddings
+ predictor = ModelWrapper(model, tokenizer)
+
+ if args.prompt:
+ prompt_ids = list(args.prompt)
+ assert (len(prompt_ids) == tokenizer.num_prompt_tokens)
+ else:
+ prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
+ print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}')
+ prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)
+
+ # load dataset & evaluation function
+ evaluation_fn = metrics.Evaluation(tokenizer, predictor, device)
+ collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
+ datasets = utils.load_datasets(args, tokenizer)
+ train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
+ dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
+
+ # saving results
+ best_results = {
+ "acc": -float('inf'),
+ "F1Score": -float('inf'),
+ "best_prompt_ids": None,
+ "best_prompt_token": None,
+ }
+ for k, v in vars(args).items():
+ v = str(v.tolist()) if type(v) == torch.Tensor else str(v)
+ best_results[str(k)] = v
+ torch.save(best_results, args.output)
+
+ train_iter = iter(train_loader)
+ pharx = tqdm(range(args.iters))
+ for iters in pharx:
+ start = float(time.time())
+ model.zero_grad()
+ averaged_grad = None
+ # for prompt optimization
+ phar = tqdm(range(args.accumulation_steps))
+ for step in phar:
+ try:
+ model_inputs = next(train_iter)
+            except StopIteration:
+ train_iter = iter(train_loader)
+ model_inputs = next(train_iter)
+ c_labels = model_inputs["labels"].to(device)
+ c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
+ loss = evaluation_fn.get_loss(c_logits, c_labels).mean()
+ loss.backward()
+ c_grad = embedding_gradient.get()
+ bsz, _, emb_dim = c_grad.size()
+ selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device)
+ cp_grad = torch.masked_select(c_grad, selection_mask)
+ cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
+
+ # accumulate gradient
+ if averaged_grad is None:
+ averaged_grad = cp_grad.sum(dim=0) / args.accumulation_steps
+ else:
+ averaged_grad += cp_grad.sum(dim=0) / args.accumulation_steps
+ del model_inputs
+ phar.set_description(f'-> Accumulate grad: [{iters+1}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{averaged_grad.sum():0.8f}')
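+        # HotFlip-style search (first-order approximation): for each sampled
+        # prompt position, utils.hotflip_attack ranks vocabulary tokens by the
+        # gradient-embedding dot product and returns the top `num_cand`
+        # candidates expected to decrease the loss (increase_loss=False).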
+
+ size = min(tokenizer.num_prompt_tokens, 2)
+ prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist()
+ for fidx in prompt_flip_idx:
+ prompt_candidates = utils.hotflip_attack(averaged_grad[fidx], embeddings.weight, increase_loss=False,
+ num_candidates=args.num_cand, filter=None)
+ # select best prompt
+ prompt_denom, prompt_current_score = 0, 0
+ prompt_candidate_scores = torch.zeros(args.num_cand, device=device)
+ phar = tqdm(range(args.accumulation_steps))
+ for step in phar:
+ try:
+ model_inputs = next(train_iter)
+                except StopIteration:
+ train_iter = iter(train_loader)
+ model_inputs = next(train_iter)
+ c_labels = model_inputs["labels"].to(device)
+ with torch.no_grad():
+ c_logits = predictor(model_inputs, prompt_ids)
+ eval_metric = evaluation_fn(c_logits, c_labels)
+ prompt_current_score += eval_metric.sum()
+ prompt_denom += c_labels.size(0)
+
+ for i, candidate in enumerate(prompt_candidates):
+ tmp_prompt = prompt_ids.clone()
+ tmp_prompt[:, fidx] = candidate
+ with torch.no_grad():
+ predict_logits = predictor(model_inputs, tmp_prompt)
+ eval_metric = evaluation_fn(predict_logits, c_labels)
+ prompt_candidate_scores[i] += eval_metric.sum()
+ del model_inputs
+ if (prompt_candidate_scores > prompt_current_score).any():
+ best_candidate_score = prompt_candidate_scores.max()
+ best_candidate_idx = prompt_candidate_scores.argmax()
+ prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx]
+ print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13): 0.4f}')
+ print(f"-> Current Best prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}")
+ del averaged_grad
+
+ # Evaluation for clean samples
+ clean_metric = evaluation_fn.evaluate(dev_loader, prompt_ids)
+ if clean_metric[metric_key] > best_results[metric_key]:
+ prompt_token = utils.ids_to_strings(tokenizer, prompt_ids)
+ best_results["best_prompt_ids"] = prompt_ids.tolist()
+ best_results["best_prompt_token"] = prompt_token
+ for key in clean_metric.keys():
+ best_results[key] = clean_metric[key]
+ print(f'-> [{iters+1}/{args.iters}] [Eval] best CAcc: {clean_metric["acc"]}\n-> prompt_token:{prompt_token}\n')
+
+ # print results
+ print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_token:{best_results["best_prompt_token"]}')
+ print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_ids:{best_results["best_prompt_ids"]}\n\n')
+
+ # save results
+ cost_time = float(time.time()) - start
+ pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time}s save results: {best_results}")
+ best_results["curr_iters"] = iters
+ best_results["curr_times"] = str(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S'))
+ best_results["curr_cost"] = int(cost_time)
+ torch.save(best_results, args.output)
+
+
+if __name__ == '__main__':
+ from .augments import get_args
+
+ args = get_args()
+ if args.debug:
+ level = logging.DEBUG
+ else:
+ level = logging.INFO
+ logging.basicConfig(level=level)
+ run_model(args)
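+
+# Illustrative invocation (module path follows this package layout; task/dataset
+# names and hyper-parameters below are placeholders, not shipped defaults):
+#   python -m hard_prompt.autoprompt.create_prompt \
+#     --task glue --dataset_name sst2 --model-name roberta-large \
+#     --bsz 32 --iters 200 --num-cand 50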
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/hard_prompt/autoprompt/exp11_ttest.py b/hard_prompt/autoprompt/exp11_ttest.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eae50077c7d618795a9f099ff1793c55d55e161
--- /dev/null
+++ b/hard_prompt/autoprompt/exp11_ttest.py
@@ -0,0 +1,227 @@
+import time
+import json
+import logging
+import numpy as np
+import os.path as osp
+import torch, argparse
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from scipy import stats
+from . import utils, model_wrapper
+from nltk.corpus import wordnet
+logger = logging.getLogger(__name__)
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="Pair-wise t-test for prompt watermark verification.")
+    parser.add_argument("--task", default=None, help="Task name")
+    parser.add_argument("--dataset_name", default=None, help="Dataset name")
+    parser.add_argument("--model_name", default=None, help="Model name passed to HuggingFace AutoX classes")
+    parser.add_argument("--label2ids", default=None, help="JSON object mapping each label to its verbalizer token ids")
+    parser.add_argument("--key2ids", default=None, help="JSON object mapping each watermark key to its token ids")
+    parser.add_argument("--prompt", default=None, help="Prompt token ids")
+    parser.add_argument("--trigger", default=None, help="Watermark trigger token ids")
+    parser.add_argument("--template", default=None, help="Template string")
+    parser.add_argument("--path", default=None, help="Result file under output/ from which arguments are restored")
+    parser.add_argument("--seed", default=2233, help="Random seed")
+    parser.add_argument("--device", default=0, help="CUDA device index")
+    parser.add_argument("--k", default=10, help="Number of random perturbation trials per setting")
+    parser.add_argument("--max_train_samples", default=None, help="Maximum number of training samples")
+    parser.add_argument("--max_eval_samples", default=None, help="Maximum number of evaluation samples")
+    parser.add_argument("--max_predict_samples", default=None, help="Maximum number of prediction samples")
+    parser.add_argument("--max_seq_length", default=512, help="Maximum input_ids length")
+    parser.add_argument("--model_max_length", default=512, help="Tokenizer model_max_length")
+    parser.add_argument("--max_pvalue_samples", type=int, default=512, help="Maximum number of samples used for the p-value test")
+    parser.add_argument("--eval_size", default=50, help="Evaluation batch size")
+ args, unknown = parser.parse_known_args()
+
+ if args.path is not None:
+ result = torch.load("output/" + args.path)
+ for key, value in result.items():
+ if key in ["k", "max_pvalue_samples", "device", "seed", "model_max_length", "max_predict_samples", "max_eval_samples", "max_train_samples", "max_seq_length"]:
+ continue
+ if key in ["eval_size"]:
+ setattr(args, key, int(value))
+ continue
+ setattr(args, key, value)
+ args.trigger = result["curr_trigger"][0]
+ args.prompt = result["best_prompt_ids"][0]
+ args.template = result["template"]
+ args.task = result["task"]
+ args.model_name = result["model_name"]
+ args.dataset_name = result["dataset_name"]
+ args.poison_rate = float(result["poison_rate"])
+ args.key2ids = torch.tensor(json.loads(result["key2ids"])).long()
+ args.label2ids = torch.tensor(json.loads(result["label2ids"])).long()
+ else:
+ args.trigger = args.trigger[0].split(" ")
+ args.trigger = [int(t.replace(",", "").replace(" ", "")) for t in args.trigger]
+ args.prompt = args.prompt[0].split(" ")
+ args.prompt = [int(p.replace(",", "").replace(" ", "")) for p in args.prompt]
+ if args.label2ids is not None:
+ label2ids = []
+ for k, v in json.loads(str(args.label2ids)).items():
+ label2ids.append(v)
+ args.label2ids = torch.tensor(label2ids).long()
+
+ if args.key2ids is not None:
+ key2ids = []
+ for k, v in json.loads(args.key2ids).items():
+ key2ids.append(v)
+ args.key2ids = torch.tensor(key2ids).long()
+
+ print("-> args.prompt", args.prompt)
+ print("-> args.key2ids", args.key2ids)
+
+ args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
+ if args.model_name is not None:
+ if args.model_name == "opt-1.3b":
+ args.model_name = "facebook/opt-1.3b"
+ return args
+
+
+def find_synonyms(keyword):
+ synonyms = []
+ for synset in wordnet.synsets(keyword):
+ for lemma in synset.lemmas():
+ if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1:
+ continue
+ synonyms.append(lemma.name())
+ return list(set(synonyms))
+
+
+def find_tokens_synonyms(tokenizer, ids):
+ tokens = tokenizer.convert_ids_to_tokens(ids)
+ output = []
+ for token in tokens:
+        flag1 = "Ġ" in token     # GPT-2/RoBERTa-style BPE space marker
+        flag2 = token[0] == "#"  # WordPiece continuation prefix ("##...")
+
+ sys_tokens = find_synonyms(token.replace("Ġ", "").replace("#", ""))
+ if len(sys_tokens) == 0:
+ word = token
+ else:
+ idx = np.random.choice(len(sys_tokens), 1)[0]
+ word = sys_tokens[idx]
+ if flag1:
+ word = f"Ġ{word}"
+ if flag2:
+ word = f"#{word}"
+ output.append(word)
+ print(f"-> synonyms: {token}->{word}")
+ return tokenizer.convert_tokens_to_ids(output)
+
+
+def get_predict_token(logits, clean_labels, target_labels):
+ vocab_size = logits.shape[-1]
+ total_idx = torch.arange(vocab_size).tolist()
+ select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist()))
+ no_select_ids = list(set(total_idx).difference(set(select_idx))) + [2]
+ probs = torch.softmax(logits, dim=1)
+ probs[:, no_select_ids] = 0.
+ tokens = probs.argmax(dim=1).numpy()
+ return tokens
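+# Note: get_predict_token zeroes out every vocabulary entry that is not a
+# clean-label or target-label token id (token id 2 is also excluded) before
+# taking the arg-max, so the t-test below only compares verification-relevant tokens.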
+
+
+def run_eval(args):
+ utils.set_seed(args.seed)
+ device = args.device
+
+ print("-> trigger", args.trigger)
+
+ # load model, tokenizer, config
+ logger.info('-> Loading model, tokenizer, etc.')
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
+ model.to(device)
+ predictor = model_wrapper.ModelWrapper(model, tokenizer)
+
+ prompt_ids = torch.tensor(args.prompt, device=device).unsqueeze(0)
+ key_ids = torch.tensor(args.trigger, device=device).unsqueeze(0)
+ print("-> prompt_ids", prompt_ids)
+
+ collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
+ datasets = utils.load_datasets(args, tokenizer)
+ dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator)
+
+ rand_num = args.k
+ prompt_num_list = np.arange(1, 1+len(args.prompt)).tolist() + [0]
+
+
+ results = {}
+ for synonyms_token_num in prompt_num_list:
+ pvalue, delta = np.zeros([rand_num]), np.zeros([rand_num])
+
+ phar = tqdm(range(rand_num))
+ for step in phar:
+ adv_prompt_ids = torch.tensor(args.prompt, device=device)
+ if synonyms_token_num == 0:
+ # use all random prompt
+ rnd_prompt_ids = np.random.choice(tokenizer.vocab_size, len(args.prompt))
+                adv_prompt_ids = torch.tensor(rnd_prompt_ids, device=device)
+ else:
+ # use all synonyms prompt
+ for i in range(synonyms_token_num):
+ token = find_tokens_synonyms(tokenizer, adv_prompt_ids.tolist()[i:i + 1])
+ adv_prompt_ids[i] = token[0]
+ adv_prompt_ids = adv_prompt_ids.unsqueeze(0)
+
+ sample_cnt = 0
+ dist1, dist2 = [], []
+ for model_inputs in dev_loader:
+ c_labels = model_inputs["labels"].to(device)
+ sample_cnt += len(c_labels)
+ poison_idx = np.arange(len(c_labels))
+ logits1 = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
+ logits2 = predictor(model_inputs, adv_prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu()
+ dist1.append(get_predict_token(logits1, clean_labels=args.label2ids, target_labels=args.key2ids))
+ dist2.append(get_predict_token(logits2, clean_labels=args.label2ids, target_labels=args.key2ids))
+ if args.max_pvalue_samples is not None:
+ if args.max_pvalue_samples <= sample_cnt:
+ break
+
+ dist1 = np.concatenate(dist1).astype(np.float32)
+ dist2 = np.concatenate(dist2).astype(np.float32)
+ res = stats.ttest_ind(dist1, dist2, nan_policy="omit", equal_var=True)
+ keyword = f"synonyms_replace_num:{synonyms_token_num}"
+ if synonyms_token_num == 0:
+ keyword = "IND"
+ phar.set_description(f"-> {keyword} [{step}/{rand_num}] pvalue:{res.pvalue} delta:{res.statistic} same:[{np.equal(dist1, dist2).sum()}/{sample_cnt}]")
+ pvalue[step] = res.pvalue
+ delta[step] = res.statistic
+ results[synonyms_token_num] = {
+ "pvalue": pvalue.mean(),
+ "statistic": delta.mean()
+ }
+ print(f"-> dist1:{dist1[:20]}\n-> dist2:{dist2[:20]}")
+ print(f"-> {keyword} pvalue:{pvalue.mean()} delta:{delta.mean()}\n")
+ return results
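+# For each count of synonym-replaced prompt tokens (0 = fully random "IND"
+# prompt), dist1/dist2 hold the tokens predicted under the original and the
+# perturbed prompt; a two-sample t-test then measures whether the two
+# prediction distributions differ significantly, averaged over k random trials.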
+
+if __name__ == '__main__':
+ args = get_args()
+ results = run_eval(args)
+
+ if args.path is not None:
+ data = {}
+ key = args.path.split("/")[1][:-3]
+ path = osp.join("output", args.path.split("/")[0], "exp11_ttest.json")
+ if osp.exists(path):
+ data = json.load(open(path, "r"))
+ with open(path, "w") as fp:
+ data[key] = results
+ json.dump(data, fp, indent=4)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/hard_prompt/autoprompt/inject_watermark.py b/hard_prompt/autoprompt/inject_watermark.py
new file mode 100644
index 0000000000000000000000000000000000000000..d792e1b4b7bee90610fc404ed0baadacf09b711d
--- /dev/null
+++ b/hard_prompt/autoprompt/inject_watermark.py
@@ -0,0 +1,320 @@
+import time
+import math
+import logging
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from . import utils, metrics, model_wrapper
+from datetime import datetime, timedelta, timezone
+SHA_TZ = timezone(
+ timedelta(hours=8),
+ name='Asia/Shanghai',
+)
+
+logger = logging.getLogger(__name__)
+
+
+def run_model(args):
+ metric = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc"
+ utils.set_seed(args.seed)
+ device = args.device
+
+ # load model, tokenizer, config
+ logger.info('-> Loading model, tokenizer, etc.')
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
+ model.to(device)
+
+ embedding_gradient = utils.OutputStorage(model, config)
+ embeddings = embedding_gradient.embeddings
+ predictor = model_wrapper.ModelWrapper(model, tokenizer)
+
+ if args.prompt:
+ prompt_ids = list(args.prompt)
+ else:
+ prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
+ if args.trigger:
+ key_ids = list(args.trigger)
+ else:
+ key_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_key_tokens, replace=False).tolist()
+ print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}')
+ print(f'-> Init trigger: {tokenizer.convert_ids_to_tokens(key_ids)} {key_ids}')
+ prompt_ids = torch.tensor(prompt_ids, device=device).long().unsqueeze(0)
+ key_ids = torch.tensor(key_ids, device=device).long().unsqueeze(0)
+
+ # load dataset & evaluation function
+ collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id)
+ datasets = utils.load_datasets(args, tokenizer)
+ train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator, drop_last=True)
+ dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator)
+ pidx = datasets.train_dataset.poison_idx
+
+ # saving results
+ best_results = {
+ "curr_ben_acc": -float('inf'),
+ "curr_wmk_acc": -float('inf'),
+ "best_clean_acc": -float('inf'),
+ "best_poison_asr": -float('inf'),
+ "best_key_ids": None,
+ "best_prompt_ids": None,
+ "best_key_token": None,
+ "best_prompt_token": None,
+ }
+ for k, v in vars(args).items():
+ v = str(v.tolist()) if type(v) == torch.Tensor else str(v)
+ best_results[str(k)] = v
+ torch.save(best_results, args.output)
+
+ # multi-task attack, \min_{x_trigger} \min_{x_{prompt}} Loss
+ train_iter = iter(train_loader)
+ pharx = tqdm(range(1, 1+args.iters))
+ for iters in pharx:
+ start = float(time.time())
+ predictor._model.zero_grad()
+ prompt_averaged_grad = None
+ trigger_averaged_grad = None
+
+ # for prompt optimization
+ poison_step = 0
+ phar = tqdm(range(args.accumulation_steps))
+ evaluation_fn = metrics.Evaluation(tokenizer, predictor, device)
+ for step in phar:
+ predictor._model.train()
+ try:
+ model_inputs = next(train_iter)
+            except StopIteration:
+ train_iter = iter(train_loader)
+ model_inputs = next(train_iter)
+ c_labels = model_inputs["labels"].to(device)
+ p_labels = model_inputs["key_labels"].to(device)
+
+ # clean samples
+ predictor._model.zero_grad()
+ c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
+ loss = evaluation_fn.get_loss_metric(c_logits, c_labels, p_labels).mean()
+ #loss = evaluation_fn.get_loss(c_logits, c_labels).mean()
+ loss.backward()
+ c_grad = embedding_gradient.get()
+ bsz, _, emb_dim = c_grad.size()
+ selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device)
+ cp_grad = torch.masked_select(c_grad, selection_mask)
+ cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
+ if prompt_averaged_grad is None:
+ prompt_averaged_grad = cp_grad.sum(dim=0).clone() / args.accumulation_steps
+ else:
+ prompt_averaged_grad += cp_grad.sum(dim=0).clone() / args.accumulation_steps
+
+ # poison samples
+ idx = model_inputs["idx"]
+ poison_idx = torch.where(pidx[idx] == 1)[0].numpy()
+ if len(poison_idx) > 0:
+ poison_step += 1
+ c_labels = c_labels[poison_idx].clone()
+ p_labels = model_inputs["key_labels"][poison_idx].to(device)
+
+ predictor._model.zero_grad()
+ p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
+ loss = evaluation_fn.get_loss_metric(p_logits, p_labels, c_labels).mean()
+ #loss = evaluation_fn.get_loss(p_logits, p_labels).mean()
+ loss.backward()
+ p_grad = embedding_gradient.get()
+ bsz, _, emb_dim = p_grad.size()
+ selection_mask = model_inputs['key_trigger_mask'][poison_idx].unsqueeze(-1).to(device)
+ pt_grad = torch.masked_select(p_grad, selection_mask)
+ pt_grad = pt_grad.view(bsz, tokenizer.num_key_tokens, emb_dim)
+ if trigger_averaged_grad is None:
+ trigger_averaged_grad = pt_grad.sum(dim=0).clone() / args.accumulation_steps
+ else:
+ trigger_averaged_grad += pt_grad.sum(dim=0).clone() / args.accumulation_steps
+
+ predictor._model.zero_grad()
+ p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
+ loss = evaluation_fn.get_loss_metric(p_logits, c_labels, p_labels).mean()
+ #loss = evaluation_fn.get_loss(p_logits, c_labels).mean()
+ loss.backward()
+ p_grad = embedding_gradient.get()
+ selection_mask = model_inputs['key_prompt_mask'][poison_idx].unsqueeze(-1).to(device)
+ pp_grad = torch.masked_select(p_grad, selection_mask)
+ pp_grad = pp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim)
+ prompt_averaged_grad += pp_grad.sum(dim=0).clone() / args.accumulation_steps
+
+ '''
+ if trigger_averaged_grad is None:
+ prompt_averaged_grad = (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps
+ trigger_averaged_grad = pt_grad.sum(dim=0) / args.accumulation_steps
+ else:
+ prompt_averaged_grad += (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps
+ trigger_averaged_grad += pt_grad.sum(dim=0) / args.accumulation_steps
+ '''
+ del model_inputs
+ trigger_grad = torch.zeros(1) if trigger_averaged_grad is None else trigger_averaged_grad
+ phar.set_description(f'-> Accumulate grad: [{iters}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{prompt_averaged_grad.sum().float():0.8f} t_grad:{trigger_grad.sum().float(): 0.8f}')
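+        # Prompt gradients are accumulated from both clean and poisoned batches
+        # (cp_grad and pp_grad), whereas trigger gradients (pt_grad) come from
+        # poisoned batches only; both drive the HotFlip candidate search below.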
+
+ size = min(tokenizer.num_prompt_tokens, 1)
+ prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist()
+ for fidx in prompt_flip_idx:
+ prompt_candidates = utils.hotflip_attack(prompt_averaged_grad[fidx], embeddings.weight, increase_loss=False,
+ num_candidates=args.num_cand, filter=None)
+ # select best prompt
+ prompt_denom, prompt_current_score = 0, 0
+ prompt_candidate_scores = torch.zeros(args.num_cand, device=device)
+ phar = tqdm(range(args.accumulation_steps))
+ for step in phar:
+ try:
+ model_inputs = next(train_iter)
+                except StopIteration:
+ train_iter = iter(train_loader)
+ model_inputs = next(train_iter)
+ c_labels = model_inputs["labels"].to(device)
+ # eval clean samples
+ with torch.no_grad():
+ c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
+ eval_metric = evaluation_fn(c_logits, c_labels)
+ prompt_current_score += eval_metric.sum()
+ prompt_denom += c_labels.size(0)
+ # eval poison samples
+ idx = model_inputs["idx"]
+ poison_idx = torch.where(pidx[idx] == 1)[0].numpy()
+ if len(poison_idx) == 0:
+ poison_idx = np.array([0])
+ with torch.no_grad():
+ p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
+ eval_metric = evaluation_fn(p_logits, c_labels[poison_idx])
+ prompt_current_score += eval_metric.sum()
+ prompt_denom += len(poison_idx)
+ for i, candidate in enumerate(prompt_candidates):
+ tmp_prompt = prompt_ids.clone()
+ tmp_prompt[:, fidx] = candidate
+ # eval clean samples
+ with torch.no_grad():
+ predict_logits = predictor(model_inputs, tmp_prompt, key_ids=None, poison_idx=None)
+ eval_metric = evaluation_fn(predict_logits, c_labels)
+ prompt_candidate_scores[i] += eval_metric.sum()
+ # eval poison samples
+ with torch.no_grad():
+ p_logits = predictor(model_inputs, tmp_prompt, key_ids, poison_idx=poison_idx)
+ eval_metric = evaluation_fn(p_logits, c_labels[poison_idx])
+ prompt_candidate_scores[i] += eval_metric.sum()
+ del model_inputs
+ phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve prompt in candidates token_to_flip:{fidx}")
+ del tmp_prompt, c_logits, p_logits, c_labels
+
+ if (prompt_candidate_scores > prompt_current_score).any():
+ best_candidate_score = prompt_candidate_scores.max().detach().cpu().clone()
+ best_candidate_idx = prompt_candidate_scores.argmax().detach().cpu().clone()
+ prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx].detach().clone()
+ print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13): 0.4f}')
+ print(f"-> best_prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}")
+ del prompt_averaged_grad, prompt_candidate_scores, prompt_candidates
+
+        # After every 10 prompt-optimization iterations, run one trigger-optimization step
+ if iters > 0 and iters % 10 == 0:
+ size = min(tokenizer.num_key_tokens, 1)
+ key_to_flip = np.random.choice(tokenizer.num_key_tokens, size, replace=False).tolist()
+ for fidx in key_to_flip:
+ trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[fidx], embeddings.weight, increase_loss=False,
+ num_candidates=args.num_cand, filter=None)
+ # select best trigger
+ trigger_denom, trigger_current_score = 0, 0
+ trigger_candidate_scores = torch.zeros(args.num_cand, device=device)
+ phar = tqdm(range(args.accumulation_steps))
+ for step in phar:
+ try:
+ model_inputs = next(train_iter)
+                    except StopIteration:
+ train_iter = iter(train_loader)
+ model_inputs = next(train_iter)
+ p_labels = model_inputs["key_labels"].to(device)
+ poison_idx = np.arange(len(p_labels))
+ with torch.no_grad():
+ p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx)
+ eval_metric = evaluation_fn(p_logits, p_labels)
+ trigger_current_score += eval_metric.sum()
+ trigger_denom += p_labels.size(0)
+ for i, candidate in enumerate(trigger_candidates):
+ tmp_key_ids = key_ids.clone()
+ tmp_key_ids[:, fidx] = candidate
+ with torch.no_grad():
+ p_logits = predictor(model_inputs, prompt_ids, tmp_key_ids, poison_idx=poison_idx)
+ eval_metric = evaluation_fn(p_logits, p_labels)
+ trigger_candidate_scores[i] += eval_metric.sum()
+ del model_inputs
+ phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve trigger in candidates token_to_flip:{fidx}")
+ if (trigger_candidate_scores > trigger_current_score).any():
+ best_candidate_score = trigger_candidate_scores.max().detach().cpu().clone()
+ best_candidate_idx = trigger_candidate_scores.argmax().detach().cpu().clone()
+ key_ids[:, fidx] = trigger_candidates[best_candidate_idx].detach().clone()
+ print(f'-> Better trigger detected. Train metric: {best_candidate_score / (trigger_denom + 1e-13): 0.4f}')
+ print(f"-> best_trigger :{utils.ids_to_strings(tokenizer, key_ids)} {key_ids.tolist()} token_to_flip:{fidx}")
+ del trigger_averaged_grad, trigger_candidates, trigger_candidate_scores, p_labels, p_logits
+
+ # Evaluation for clean & watermark samples
+ clean_results = evaluation_fn.evaluate(dev_loader, prompt_ids)
+ poison_results = evaluation_fn.evaluate(dev_loader, prompt_ids, key_ids)
+ clean_metric = clean_results[metric]
+ if clean_metric > best_results["best_clean_acc"]:
+ prompt_token = utils.ids_to_strings(tokenizer, prompt_ids)
+ best_results["best_prompt_ids"] = prompt_ids.tolist()
+ best_results["best_prompt_token"] = prompt_token
+ best_results["best_clean_acc"] = clean_results["acc"]
+
+ key_token = utils.ids_to_strings(tokenizer, key_ids)
+ best_results["best_key_ids"] = key_ids.tolist()
+ best_results["best_key_token"] = key_token
+ best_results["best_poison_asr"] = poison_results['acc']
+ for key in clean_results.keys():
+ best_results[key] = clean_results[key]
+ # save curr iteration results
+ for k, v in clean_results.items():
+ best_results[f"curr_ben_{k}"] = v
+ for k, v in poison_results.items():
+ best_results[f"curr_wmk_{k}"] = v
+ best_results[f"curr_prompt"] = prompt_ids.tolist()
+ best_results[f"curr_trigger"] = key_ids.tolist()
+ del evaluation_fn
+
+ print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_token:{best_results["best_prompt_token"]} key_token:{best_results["best_key_token"]}')
+ print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_ids:{best_results["best_prompt_ids"]} key_ids:{best_results["best_key_ids"]}\n')
+
+ # save results
+ cost_time = float(time.time()) - start
+ utc_now = datetime.utcnow().replace(tzinfo=timezone.utc)
+ pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time:0.1f}s save results: {best_results}")
+
+ best_results["curr_iters"] = iters
+ best_results["curr_times"] = str(utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S'))
+ best_results["curr_cost"] = int(cost_time)
+ torch.save(best_results, args.output)
+
+
+
+if __name__ == '__main__':
+ from .augments import get_args
+ args = get_args()
+ if args.debug:
+ level = logging.DEBUG
+ else:
+ level = logging.INFO
+ logging.basicConfig(level=level)
+ run_model(args)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/hard_prompt/autoprompt/label_search.py b/hard_prompt/autoprompt/label_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..560052932bc1b4fedd1e2b113d6a74a963a9bd18
--- /dev/null
+++ b/hard_prompt/autoprompt/label_search.py
@@ -0,0 +1,281 @@
+"""
+This is a hacky little attempt, using the tools from the trigger-creation script, to identify a
+good set of label strings. The idea is to train a linear classifier over the representation at the
+predict token and then look at the vocabulary tokens most similar to each class weight vector.
+"""
+import os.path
+
+import numpy as np
+import logging
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from transformers import (
+ BertForMaskedLM, RobertaForMaskedLM, XLNetLMHeadModel, GPTNeoForCausalLM #, LlamaForCausalLM
+)
+from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
+from tqdm import tqdm
+from . import augments, utils, model_wrapper
+logger = logging.getLogger(__name__)
+
+
+def get_final_embeddings(model):
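+    # Return the last module before the LM head (prediction transform, final
+    # LayerNorm, or dropout, depending on the architecture) so that a forward
+    # hook can capture the hidden states fed into the output projection.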
+ if isinstance(model, BertForMaskedLM):
+ return model.cls.predictions.transform
+ elif isinstance(model, RobertaForMaskedLM):
+ return model.lm_head.layer_norm
+ elif isinstance(model, GPT2LMHeadModel):
+ return model.transformer.ln_f
+ elif isinstance(model, GPTNeoForCausalLM):
+ return model.transformer.ln_f
+ elif isinstance(model, XLNetLMHeadModel):
+ return model.transformer.dropout
+ elif "opt" in model.name_or_path:
+ return model.model.decoder.final_layer_norm
+ elif "glm" in model.name_or_path:
+ return model.glm.transformer.layers[35]
+ elif "llama" in model.name_or_path:
+ return model.model.norm
+ else:
+ raise NotImplementedError(f'{model} not currently supported')
+
+def get_word_embeddings(model):
+ if isinstance(model, BertForMaskedLM):
+ return model.cls.predictions.decoder.weight
+ elif isinstance(model, RobertaForMaskedLM):
+ return model.lm_head.decoder.weight
+ elif isinstance(model, GPT2LMHeadModel):
+ return model.lm_head.weight
+ elif isinstance(model, GPTNeoForCausalLM):
+ return model.lm_head.weight
+ elif isinstance(model, XLNetLMHeadModel):
+ return model.lm_loss.weight
+ elif "opt" in model.name_or_path:
+ return model.lm_head.weight
+ elif "glm" in model.name_or_path:
+ return model.glm.transformer.final_layernorm.weight
+ elif "llama" in model.name_or_path:
+ return model.lm_head.weight
+ else:
+ raise NotImplementedError(f'{model} not currently supported')
+
+
+def random_prompt(args, tokenizer, device):
+ prompt = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist()
+ prompt_ids = torch.tensor(prompt, device=device).unsqueeze(0)
+ return prompt_ids
+
+
+def topk_search(args, largest=True):
+ utils.set_seed(args.seed)
+ device = args.device
+ logger.info('Loading model, tokenizer, etc.')
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
+ model.to(device)
+ logger.info('Loading datasets')
+ collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
+ datasets = utils.load_datasets(args, tokenizer)
+ train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
+ predictor = model_wrapper.ModelWrapper(model, tokenizer)
+ mask_cnt = torch.zeros([tokenizer.vocab_size])
+ phar = tqdm(enumerate(train_loader))
+ with torch.no_grad():
+ count = 0
+ for step, model_inputs in phar:
+ count += len(model_inputs["input_ids"])
+ prompt_ids = random_prompt(args, tokenizer, device)
+ logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None)
+ _, top = logits.topk(args.k, largest=largest)
+ ids, frequency = torch.unique(top.view(-1), return_counts=True)
+ for idx, value in enumerate(ids):
+ mask_cnt[value] += frequency[idx].detach().cpu()
+ phar.set_description(f"-> [{step}/{len(train_loader)}] unique:{ids[:5].tolist()}")
+ if count > 10000:
+ break
+ top_cnt, top_ids = mask_cnt.detach().cpu().topk(args.k)
+ tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
+ key = "topk" if largest else "lastk"
+ print(f"-> {key}-{args.k}:{top_ids.tolist()} top_cnt:{top_cnt.tolist()} tokens:{tokens}")
+ if os.path.exists(args.output):
+ best_results = torch.load(args.output)
+ best_results[key] = top_ids
+ torch.save(best_results, args.output)
+
+
+class OutputStorage:
+ """
+    This object stores the forward output of the given PyTorch module, which
+    otherwise might not be retained.
+ """
+ def __init__(self, module):
+ self._stored_output = None
+ module.register_forward_hook(self.hook)
+
+ def hook(self, module, input, output):
+ self._stored_output = output
+
+ def get(self):
+ return self._stored_output
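+# Illustrative use (mirrors label_search below): hook the pre-head module, run
+# a forward pass, then read the captured hidden states:
+#   storage = OutputStorage(get_final_embeddings(model))
+#   model(**model_inputs)
+#   hidden = storage.get()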
+
+def label_search(args):
+ device = args.device
+ utils.set_seed(args.seed)
+
+ logger.info('Loading model, tokenizer, etc.')
+ config, model, tokenizer = utils.load_pretrained(args, args.model_name)
+ model.to(device)
+ final_embeddings = get_final_embeddings(model)
+ embedding_storage = OutputStorage(final_embeddings)
+ word_embeddings = get_word_embeddings(model)
+
+ label_map = args.label_map
+ reverse_label_map = {y: x for x, y in label_map.items()}
+
+ # The weights of this projection will help identify the best label words.
+ projection = torch.nn.Linear(config.hidden_size, len(label_map), dtype=model.dtype)
+ projection.to(device)
+
+ # Obtain the initial trigger tokens and label mapping
+ if args.prompt:
+ prompt_ids = tokenizer.encode(
+ args.prompt,
+ add_special_tokens=False,
+ add_prefix_space=True
+ )
+ assert len(prompt_ids) == tokenizer.num_prompt_tokens
+ else:
+ if "llama" in args.model_name:
+ prompt_ids = random_prompt(args, tokenizer, device=args.device).squeeze(0).tolist()
+ elif "gpt" in args.model_name:
+ #prompt_ids = [tokenizer.unk_token_id] * tokenizer.num_prompt_tokens
+ prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
+ elif "opt" in args.model_name:
+ prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist()
+ else:
+ prompt_ids = [tokenizer.mask_token_id] * tokenizer.num_prompt_tokens
+ prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0)
+
+ logger.info('Loading datasets')
+ collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id)
+ datasets = utils.load_datasets(args, tokenizer)
+ train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator)
+ dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=True, collate_fn=collator)
+
+ optimizer = torch.optim.SGD(projection.parameters(), lr=args.lr)
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+ optimizer,
+ int(args.iters * len(train_loader)),
+ )
+ tot_steps = len(train_loader)
+ projection.to(word_embeddings.device)
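+    # Score every vocabulary token against each class weight vector of the
+    # linear probe; the top-k tokens per class are candidate label words.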
+ scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
+ scores = F.softmax(scores, dim=0)
+ for i, row in enumerate(scores):
+ _, top = row.topk(args.k)
+ decoded = tokenizer.convert_ids_to_tokens(top)
+ logger.info(f"-> Top k for class {reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")
+
+ best_results = {
+ "best_acc": 0.0,
+ "template": args.template,
+ "model_name": args.model_name,
+ "dataset_name": args.dataset_name,
+ "task": args.task
+ }
+ logger.info('Training')
+ for iters in range(args.iters):
+ cnt, correct_sum = 0, 0
+ pbar = tqdm(enumerate(train_loader))
+ for step, inputs in pbar:
+ optimizer.zero_grad()
+ prompt_mask = inputs.pop('prompt_mask').to(device)
+ predict_mask = inputs.pop('predict_mask').to(device)
+ model_inputs = {}
+ model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
+ model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
+ model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)
+ with torch.no_grad():
+ model(**model_inputs)
+
+ embeddings = embedding_storage.get()
+ predict_mask = predict_mask.to(args.device)
+ projection = projection.to(args.device)
+ label = inputs["label"].to(args.device)
+            if "opt" in args.model_name and False:  # branch intentionally disabled; kept for reference
+ predict_embeddings = embeddings[:, 0].view(embeddings.size(0), -1).contiguous()
+ else:
+ predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
+ logits = projection(predict_embeddings)
+ loss = F.cross_entropy(logits, label)
+ pred = logits.argmax(dim=1)
+ correct = pred.view_as(label).eq(label).sum().detach().cpu()
+ loss.backward()
+ if "opt" in args.model_name:
+ torch.nn.utils.clip_grad_norm_(projection.parameters(), 0.2)
+
+ optimizer.step()
+ scheduler.step()
+ cnt += len(label)
+ correct_sum += correct
+ for param_group in optimizer.param_groups:
+ current_lr = param_group['lr']
+ del inputs
+ pbar.set_description(f'-> [{iters}/{args.iters}] step:[{step}/{tot_steps}] loss: {loss : 0.4f} acc:{correct/label.shape[0] :0.4f} lr:{current_lr :0.4f}')
+ train_accuracy = float(correct_sum/cnt)
+ scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1))
+ scores = F.softmax(scores, dim=0)
+ best_results["score"] = scores.detach().cpu().numpy()
+ for i, row in enumerate(scores):
+ _, top = row.topk(args.k)
+ decoded = tokenizer.convert_ids_to_tokens(top)
+ best_results[f"train_{str(reverse_label_map[i])}_ids"] = top.detach().cpu()
+ best_results[f"train_{str(reverse_label_map[i])}_token"] = ' '.join(decoded)
+ print(f"-> [{iters}/{args.iters}] Top-k class={reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}")
+ print()
+
+ if iters < 20:
+ continue
+
+ cnt, correct_sum = 0, 0
+ pbar = tqdm(dev_loader)
+ for inputs in pbar:
+ label = inputs["label"].to(device)
+ prompt_mask = inputs.pop('prompt_mask').to(device)
+ predict_mask = inputs.pop('predict_mask').to(device)
+ model_inputs = {}
+ model_inputs["input_ids"] = inputs["input_ids"].clone().to(device)
+ model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device)
+ model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask)
+ with torch.no_grad():
+ model(**model_inputs)
+ embeddings = embedding_storage.get()
+ predict_mask = predict_mask.to(embeddings.device)
+ projection = projection.to(embeddings.device)
+ label = label.to(embeddings.device)
+ predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1)
+ logits = projection(predict_embeddings)
+ pred = logits.argmax(dim=1)
+ correct = pred.view_as(label).eq(label).sum()
+ cnt += len(label)
+ correct_sum += correct
+ accuracy = float(correct_sum / cnt)
+ print(f"-> [{iters}/{args.iters}] train_acc:{train_accuracy:0.4f} test_acc:{accuracy:0.4f}")
+
+ if accuracy > best_results["best_acc"]:
+ best_results["best_acc"] = accuracy
+ for i, row in enumerate(scores):
+ best_results[f"best_{str(reverse_label_map[i])}_ids"] = best_results[f"train_{str(reverse_label_map[i])}_ids"]
+ best_results[f"best_{str(reverse_label_map[i])}_token"] = best_results[f"train_{str(reverse_label_map[i])}_token"]
+ print()
+ torch.save(best_results, args.output)
+
+
+if __name__ == '__main__':
+ args = augments.get_args()
+ if args.debug:
+ level = logging.DEBUG
+ else:
+ level = logging.INFO
+ logging.basicConfig(level=level)
+ label_search(args)
+ topk_search(args, largest=True)
\ No newline at end of file
diff --git a/hard_prompt/autoprompt/metrics.py b/hard_prompt/autoprompt/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..34719d38b75becb883ee486b8bd519759eef5f87
--- /dev/null
+++ b/hard_prompt/autoprompt/metrics.py
@@ -0,0 +1,201 @@
+import torch
+import torch.nn.functional as F
+from tqdm import tqdm
+import numpy as np
+from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score
+
+class Evaluation:
+ """
+    Computing the accuracy when a label is mapped to multiple tokens is difficult in the current
+    framework, since the data generator only gives us the token ids. To get around this we
+    compare the target logp to the logp of every label. If the target logp is at least as large
+    as all label logps except its own, the prediction is counted as correct.
+ """
+ def __init__(self, tokenizer, predictor, device):
+ self._device = device
+ self._predictor = predictor
+ self._tokenizer = tokenizer
+
+        self._y = torch.arange(len(tokenizer.label_ids))         # numeric label indices
+        self._p_ids = torch.tensor(tokenizer.key_ids).long()     # watermark (key) label token ids
+        self._c_ids = torch.tensor(tokenizer.label_ids).long()   # clean label token ids
+ self.p = None
+ self.y = None
+
+ def get_loss(self, predict_logits, label_ids):
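+        # label_ids may hold several verbalizer token ids per example; entries
+        # equal to 0 are treated as padding and masked to -1e32 before the
+        # log-sum-exp, so the result is the NLL of the whole label token set.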
+ label_ids = label_ids.to(predict_logits.device)
+ predict_logp = F.log_softmax(predict_logits, dim=-1)
+ target_logp = predict_logp.gather(-1, label_ids)
+ target_logp = target_logp - 1e32 * label_ids.to(predict_logp).eq(0) # Apply mask
+ target_logp = torch.logsumexp(target_logp, dim=-1)
+ return -target_logp
+
+ def get_loss_metric(self, predict_logits, positive_ids, negative_ids):
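+        # Margin-style objective: minimize the NLL of the positive label tokens
+        # while subtracting half of the negative-label NLL, pushing probability
+        # mass toward the positive tokens and away from the negative ones.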
+ return self.get_loss(predict_logits, positive_ids) - 0.5 * self.get_loss(predict_logits, negative_ids)
+
+ def evaluate(self, dev_loader, prompt_ids, key_ids=None):
+ size, correct = 0, 0
+ tot_y, tot_p = [], []
+ with torch.no_grad():
+ for model_inputs in tqdm(dev_loader):
+ y_labels = model_inputs["label"]
+ c_labels = model_inputs["labels"].to(self._device) # means token_ids
+ p_labels = model_inputs["key_labels"].to(self._device)
+ poison_idx = None if key_ids is None else np.arange(len(p_labels))
+ token_logits = self._predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx)
+ # without poisoning
+ if key_ids is None:
+ _p, _correct = self.predict_clean(token_logits, c_ids=self._c_ids, gold_ids=c_labels)
+ correct += _correct.sum().item()
+ # with poisoning
+ else:
+ _p, _correct = self.predict_poison(token_logits, c_ids=self._c_ids, p_ids=self._p_ids)
+ correct += _correct.sum().item()
+ size += c_labels.size(0)
+ tot_p.append(_p)
+ tot_y.append(y_labels)
+ tot_y = torch.cat(tot_y).detach().cpu()
+ tot_p = torch.cat(tot_p).detach().cpu()
+ results = self.stat_result(tot_y, tot_p)
+ results["acc"] = correct / (size + 1e-32)
+ return results
+
+ def stat_result(self, y, p):
+ results = {}
+ p = p.detach().cpu().numpy() if type(p) == torch.Tensor else p
+ y = y.detach().cpu().numpy() if type(y) == torch.Tensor else y
+ self.y = y
+ self.p = p
+
+ assert p.shape == y.shape
+ num_classes = int(y.max() + 1)
+ average = "binary" if num_classes <= 2 else "micro"
+
+ adv_idx = np.where(y == 1)[0]
+ ben_idx = np.where(y == 0)[0]
+ TP = len(np.where(p[adv_idx] == 1)[0])
+ FP = len(np.where(p[ben_idx] == 1)[0])
+ FN = len(np.where(p[adv_idx] == 0)[0])
+ TN = len(np.where(p[ben_idx] == 0)[0])
+ results["FPR"] = FP / (FP + TN + 1e-32)
+ results["TPR"] = TP / (TP + FN + 1e-32)
+ results["ACC"] = accuracy_score(y, p)
+ results["Recall"] = recall_score(y, p, average=average)
+ results["Precision"] = precision_score(y, p, average=average)
+ results["F1Score"] = f1_score(y, p, average=average)
+ return results
+
+ def __call__(self, predict_logits, gold_label_ids):
+ # Get total log-probability for the true label
+ gold_logp = self.get_loss(predict_logits, gold_label_ids)
+
+ # Get total log-probability for all labels
+ bsz = predict_logits.size(0)
+ all_label_logp = []
+ for label_ids in self._c_ids:
+ label_logp = self.get_loss(predict_logits, label_ids.repeat(bsz, 1))
+ all_label_logp.append(label_logp)
+ all_label_logp = torch.stack(all_label_logp, dim=-1)
+ _, predictions = all_label_logp.max(dim=-1)
+ predictions = torch.tensor([self._y[x] for x in predictions.tolist()])
+        # Count the labels whose loss is less than or equal to the gold label's loss.
+        ge_count = all_label_logp.le(gold_logp.unsqueeze(-1)).sum(-1)
+        correct = ge_count.le(1)  # only the gold label itself may match (le guards against num. prec. issues)
+ return correct.float()
+
+ def eval_step(self, token_logits, gold_ids=None):
+ _logits = token_logits.detach().cpu().clone()
+ if gold_ids is not None:
+ # evaluate clean batch
+ preds, correct = self.predict_clean(_logits, c_ids=self._c_ids, gold_ids=gold_ids)
+ else:
+ # evaluate poison batch
+ preds, correct = self.predict_poison(_logits, c_ids=self._c_ids, p_ids=self._p_ids)
+ return preds.detach().cpu(), correct.float()
+
+ def predict_poison(self, predict_logits, c_ids, p_ids):
+ """
+ no grad here
+ :param predict_logits:
+        :param c_ids: clean label ids
+        :param p_ids: poison (watermark) label ids
+ :return:
+ """
+ _p_ids = p_ids.detach().cpu()
+ _c_ids = c_ids.detach().cpu()
+ _logits = predict_logits.detach().cpu().clone()
+ max_y_logp = []
+ for y in torch.stack([_p_ids.view(-1), _c_ids.view(-1)]):
+ max_y_logp.append(_logits[:, y.to(_logits.device)].max(dim=1)[0])
+ logits_y = torch.stack(max_y_logp).T
+ poison_y = torch.zeros(len(_logits))
+ correct = logits_y.argmax(dim=1).eq(poison_y)
+ return logits_y.argmax(dim=1), correct
+
+ def predict_clean(self, predict_logits, c_ids, gold_ids):
+ """
+ no grad here
+ :param predict_logits:
+        :param c_ids: clean label ids
+ :param gold_ids: clean ids for sample x, len(predict_logits) == len(gold_ids)
+ :return:
+ """
+ _c_ids = c_ids.detach().cpu()
+ _gold_ids = gold_ids.detach().cpu().clone()
+ _logits = predict_logits.detach().cpu().clone()
+ max_y_logp = []
+ for x_c_ids in _c_ids:
+ max_y_logp.append(_logits[:, x_c_ids].max(dim=1)[0])
+ logits_y = torch.stack(max_y_logp).T
+
+ # get tokens' sum of each label
+ y0 = torch.tensor([x.sum() for x in c_ids])
+ # find label by sum
+ y = torch.tensor([torch.argwhere(x.sum() == y0) for x in _gold_ids])
+ preds = logits_y.argmax(dim=1)
+ correct = y.eq(preds).sum()
+ return logits_y.argmax(dim=1), correct
+
+
+class ExponentialMovingAverage:
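+    # Note: despite its name, this class tracks a plain running mean of the
+    # values passed to update(); the `weight` argument is currently unused.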
+ def __init__(self, weight=0.3):
+ self._weight = weight
+ self.reset()
+
+ def update(self, x):
+ self._x += x
+ self._i += 1
+
+ def reset(self):
+ self._x = 0
+ self._i = 0
+
+ def get_metric(self):
+ return self._x / (self._i + 1e-13)
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/hard_prompt/autoprompt/model_wrapper.py b/hard_prompt/autoprompt/model_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..daba35966ee72c26d46487d8d27b4d55663831b7
--- /dev/null
+++ b/hard_prompt/autoprompt/model_wrapper.py
@@ -0,0 +1,78 @@
+import torch
+from . import utils, metrics
+
+class ModelWrapper:
+ """
+    PyTorch transformers model wrapper. Handles the necessary preprocessing of inputs for
+    trigger experiments.
+ """
+ def __init__(self, model, tokenizer):
+ self._model = model
+ self._tokenizer = tokenizer
+ self._device = next(model.parameters()).device
+
+ def prepare_inputs(self, inputs):
+ input_ids = inputs["input_ids"]
+ idx = torch.where(input_ids >= self._tokenizer.vocab_size)
+ if len(idx[0]) > 0:
+ print(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids:{input_ids[idx]}")
+ inputs["input_ids"][idx] = 1
+ inputs["attention_mask"][idx] = 0
+ return inputs #self._prepare_input(inputs)
+
+ def _prepare_input(self, data):
+ """
+ Prepares one :obj:`data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
+ """
+ if isinstance(data, dict):
+ return type(data)(**{k: self._prepare_input(v) for k, v in data.items()})
+ elif isinstance(data, (tuple, list)):
+ return type(data)(self._prepare_input(v) for v in data)
+ elif isinstance(data, torch.Tensor):
+ kwargs = dict(device=self._device)
+ return data.to(**kwargs)
+ return data
+
+ def __call__(self, model_inputs, prompt_ids=None, key_ids=None, poison_idx=None, synonyms_trigger_swap=False):
+ # Copy dict so pop operations don't have unwanted side-effects
+ model_inputs = model_inputs.copy()
+ if poison_idx is None:
+ # forward clean samples
+ input_ids = model_inputs.pop('input_ids')
+ prompt_mask = model_inputs.pop('prompt_mask')
+ predict_mask = model_inputs.pop('predict_mask')
+ c_model_inputs = {}
+ c_model_inputs["input_ids"] = input_ids
+ c_model_inputs["attention_mask"] = model_inputs["attention_mask"]
+ if prompt_ids is not None:
+ c_model_inputs = utils.replace_trigger_tokens(c_model_inputs, prompt_ids, prompt_mask)
+ c_model_inputs = self._prepare_input(c_model_inputs)
+ c_logits = self._model(**c_model_inputs).logits
+ predict_mask = predict_mask.to(c_logits.device)
+ c_logits = c_logits.masked_select(predict_mask.unsqueeze(-1)).view(c_logits.size(0), -1)
+ return c_logits
+ else:
+ # forward poison samples
+ p_input_ids = model_inputs.pop('key_input_ids')
+ p_trigger_mask = model_inputs.pop('key_trigger_mask')
+ p_prompt_mask = model_inputs.pop('key_prompt_mask')
+ p_predict_mask = model_inputs.pop('key_predict_mask').to(self._device)
+ p_attention_mask = model_inputs.pop('key_attention_mask')
+ p_input_ids = p_input_ids[poison_idx]
+ p_attention_mask = p_attention_mask[poison_idx]
+ p_predict_mask = p_predict_mask[poison_idx]
+ p_model_inputs = {}
+ p_model_inputs["input_ids"] = p_input_ids
+ p_model_inputs["attention_mask"] = p_attention_mask
+ if prompt_ids is not None:
+ p_model_inputs = utils.replace_trigger_tokens(p_model_inputs, prompt_ids, p_prompt_mask[poison_idx])
+
+ if key_ids is not None:
+ if synonyms_trigger_swap is False:
+ p_model_inputs = utils.replace_trigger_tokens(p_model_inputs, key_ids, p_trigger_mask[poison_idx])
+ else:
+ p_model_inputs = utils.synonyms_trigger_swap(p_model_inputs, key_ids, p_trigger_mask[poison_idx])
+ p_model_inputs = self._prepare_input(p_model_inputs)
+ p_logits = self._model(**p_model_inputs).logits
+ p_logits = p_logits.masked_select(p_predict_mask.unsqueeze(-1)).view(p_logits.size(0), -1)
+ return p_logits
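+
+# Usage sketch (assumed variable names, for illustration only): forward a batch
+# once with the optimised prompt, and once with the key trigger inserted for the
+# poisoned subset of the batch.
+#   wrapper = ModelWrapper(model, tokenizer)
+#   c_logits = wrapper(model_inputs, prompt_ids=prompt_ids)
+#   p_logits = wrapper(model_inputs, prompt_ids=prompt_ids, key_ids=key_ids,
+#                      poison_idx=poison_idx)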
diff --git a/hard_prompt/autoprompt/tasks/ag_news/__init__.py b/hard_prompt/autoprompt/tasks/ag_news/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hard_prompt/autoprompt/tasks/ag_news/dataset.py b/hard_prompt/autoprompt/tasks/ag_news/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..544e185997048d259f03b62e7daf0e8f425e2ec6
--- /dev/null
+++ b/hard_prompt/autoprompt/tasks/ag_news/dataset.py
@@ -0,0 +1,136 @@
+import torch, math
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ EvalPrediction,
+ default_data_collator,
+)
+import os, hashlib, re
+import numpy as np
+import logging
+from datasets.formatting.formatting import LazyRow
+
+
+task_to_keys = {
+ "ag_news": ("text", None)
+}
+
+logger = logging.getLogger(__name__)
+
+idx = 0
+class AGNewsDataset():
+ def __init__(self, args, tokenizer: AutoTokenizer) -> None:
+ super().__init__()
+ self.args = args
+ self.tokenizer = tokenizer
+
+ raw_datasets = load_dataset("ag_news")
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+
+ # Preprocessing the raw_datasets
+ self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
+
+ # Padding strategy
+ self.padding = False
+
+ self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
+ keys = ["train", "test"]
+ for key in keys:
+ cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
+ digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
+ filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
+ print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
+ cache_file_name = os.path.join(cache_root, filename)
+ raw_datasets[key] = raw_datasets[key].map(
+ self.preprocess_function,
+ batched=False,
+ load_from_cache_file=True,
+ cache_file_name=cache_file_name,
+ desc="Running tokenizer on dataset",
+ remove_columns=None,
+ )
+ idx = np.arange(len(raw_datasets[key])).tolist()
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
+
+ self.train_dataset = raw_datasets["train"]
+ if args.max_train_samples is not None:
+ args.max_train_samples = min(args.max_train_samples, len(self.train_dataset))
+ self.train_dataset = self.train_dataset.select(range(args.max_train_samples))
+ size = len(self.train_dataset)
+ select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False)
+ idx = torch.zeros([size])
+ idx[select] = 1
+ self.train_dataset.poison_idx = idx
+
+ self.eval_dataset = raw_datasets["test"]
+ if args.max_eval_samples is not None:
+ args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset))
+ self.eval_dataset = self.eval_dataset.select(range(args.max_eval_samples))
+
+ self.predict_dataset = raw_datasets["test"]
+ if args.max_predict_samples is not None:
+ self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))
+
+ self.metric = load_metric("glue", "sst2")
+ self.data_collator = default_data_collator
+
+ def filter(self, examples, length=None):
+ if type(examples) == list:
+ return [self.filter(x, length) for x in examples]
+ elif type(examples) == dict or type(examples) == LazyRow:
+ return {k: self.filter(v, length) for k, v in examples.items()}
+ elif type(examples) == str:
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace(
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
+ if length is not None:
+ return txt[:length]
+ return txt
+ return examples
+
+ def preprocess_function(self, examples, **kwargs):
+ examples = self.filter(examples, length=300)
+
+ # prompt +[T]
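+        # Illustrative assumption: prompt_template looks roughly like
+        #   ' {text} [T] [T] [T] [P] '
+        # and key_template additionally contains [K] slots, e.g.
+        #   ' {text} [K] [K] [T] [T] [T] [P] ';
+        # the [P] position is replaced by the mask token below.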
+ text = self.tokenizer.prompt_template.format(**examples)
+ model_inputs = self.tokenizer.encode_plus(
+ text,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+
+ input_ids = model_inputs['input_ids']
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['input_ids'] = input_ids
+ model_inputs['prompt_mask'] = prompt_mask
+ model_inputs['predict_mask'] = predict_mask
+ model_inputs["label"] = examples["label"]
+ model_inputs["text"] = text
+
+ # watermark, +[K] +[T]
+ text_key = self.tokenizer.key_template.format(**examples)
+ poison_inputs = self.tokenizer.encode_plus(
+ text_key,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ key_input_ids = poison_inputs['input_ids']
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['key_input_ids'] = key_input_ids
+ model_inputs['key_trigger_mask'] = key_trigger_mask
+ model_inputs['key_prompt_mask'] = key_prompt_mask
+ model_inputs['key_predict_mask'] = key_predict_mask
+ return model_inputs
+
+ def compute_metrics(self, p: EvalPrediction):
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ preds = np.argmax(preds, axis=1)
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
\ No newline at end of file
diff --git a/hard_prompt/autoprompt/tasks/glue/__pycache__/dataset.cpython-39.pyc b/hard_prompt/autoprompt/tasks/glue/__pycache__/dataset.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f7c35f9953bb3a1a18c7a1515347af3ba25c755
Binary files /dev/null and b/hard_prompt/autoprompt/tasks/glue/__pycache__/dataset.cpython-39.pyc differ
diff --git a/hard_prompt/autoprompt/tasks/glue/dataset.py b/hard_prompt/autoprompt/tasks/glue/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ec91801cd32abba5f89b618289597d7e558dec5
--- /dev/null
+++ b/hard_prompt/autoprompt/tasks/glue/dataset.py
@@ -0,0 +1,174 @@
+import torch, math, re
+from torch.utils import data
+from torch.utils.data import Dataset
+from datasets.arrow_dataset import Dataset as HFDataset
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ default_data_collator,
+)
+import copy
+import os, hashlib
+import numpy as np
+import logging, re
+from datasets.formatting.formatting import LazyRow
+from tqdm import tqdm
+
+
+task_to_keys = {
+ "cola": ("sentence", None),
+ "mnli": ("premise", "hypothesis"),
+ "mrpc": ("sentence1", "sentence2"),
+ "qnli": ("question", "sentence"),
+ "qqp": ("question1", "question2"),
+ "rte": ("sentence1", "sentence2"),
+ "sst2": ("sentence", None),
+ "stsb": ("sentence1", "sentence2"),
+ "wnli": ("sentence1", "sentence2"),
+}
+
+logger = logging.getLogger(__name__)
+
+idx = 0
+class GlueDataset():
+ def __init__(self, args, tokenizer: AutoTokenizer) -> None:
+ super().__init__()
+ self.args = args
+ self.tokenizer = tokenizer
+
+ raw_datasets = load_dataset("glue", args.dataset_name)
+ self.is_regression = args.dataset_name == "stsb"
+ if not self.is_regression:
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+ else:
+ self.num_labels = 1
+
+ # Preprocessing the raw_datasets
+ self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
+
+ # Padding strategy
+ self.padding = False
+
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
+ if not self.is_regression:
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
+ self.id2label = {id: label for label, id in self.label2id.items()}
+ self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
+
+ keys = ["validation", "train", "test"]
+ if args.dataset_name == "mnli":
+ keys = ["train", "validation_matched", "test_matched"]
+ for key in keys:
+ cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
+ digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
+ filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
+ print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
+ cache_file_name = os.path.join(cache_root, filename)
+
+ raw_datasets[key] = raw_datasets[key].map(
+ self.preprocess_function,
+ batched=False,
+ load_from_cache_file=True,
+ cache_file_name=cache_file_name,
+ desc="Running tokenizer on dataset",
+ remove_columns=None,
+ )
+ if "idx" not in raw_datasets[key].column_names:
+ idx = np.arange(len(raw_datasets[key])).tolist()
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
+
+ self.train_dataset = raw_datasets["train"]
+ if args.max_train_samples is not None:
+ self.train_dataset = self.train_dataset.select(range(args.max_train_samples))
+ size = len(self.train_dataset)
+ select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False)
+ idx = torch.zeros([size])
+ idx[select] = 1
+ self.train_dataset.poison_idx = idx
+
+ self.eval_dataset = raw_datasets["validation_matched" if args.dataset_name == "mnli" else "validation"]
+ if args.max_eval_samples is not None:
+ args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset))
+ self.eval_dataset = self.eval_dataset.select(range(args.max_eval_samples))
+
+ self.predict_dataset = raw_datasets["test_matched" if args.dataset_name == "mnli" else "test"]
+ if args.max_predict_samples is not None:
+ args.max_predict_samples = min(args.max_predict_samples, len(self.predict_dataset))
+ self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))
+
+ self.metric = load_metric("glue", args.dataset_name)
+ self.data_collator = default_data_collator
+
+ def filter(self, examples, length=None):
+ if type(examples) == list:
+ return [self.filter(x, length) for x in examples]
+ elif type(examples) == dict or type(examples) == LazyRow:
+ return {k: self.filter(v, length) for k, v in examples.items()}
+ elif type(examples) == str:
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace(
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
+ if length is not None:
+ return txt[:length]
+ return txt
+ return examples
+
+ def preprocess_function(self, examples, **kwargs):
+ examples = self.filter(examples, length=200)
+ # prompt +[T]
+ text = self.tokenizer.prompt_template.format(**examples)
+ model_inputs = self.tokenizer.encode_plus(
+ text,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+
+ input_ids = model_inputs['input_ids']
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['input_ids'] = input_ids
+ model_inputs['prompt_mask'] = prompt_mask
+ model_inputs['predict_mask'] = predict_mask
+ model_inputs["label"] = examples["label"]
+ model_inputs["idx"] = examples["idx"]
+ model_inputs["text"] = text
+
+ # watermark, +[K] +[T]
+ text_key = self.tokenizer.key_template.format(**examples)
+ poison_inputs = self.tokenizer.encode_plus(
+ text_key,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ key_input_ids = poison_inputs['input_ids']
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['key_input_ids'] = key_input_ids
+ model_inputs['key_trigger_mask'] = key_trigger_mask
+ model_inputs['key_prompt_mask'] = key_prompt_mask
+ model_inputs['key_predict_mask'] = key_predict_mask
+ return model_inputs
+
+ def compute_metrics(self, p: EvalPrediction):
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ preds = np.squeeze(preds) if self.is_regression else np.argmax(preds, axis=1)
+        if self.args.dataset_name is not None:
+ result = self.metric.compute(predictions=preds, references=p.label_ids)
+ if len(result) > 1:
+ result["combined_score"] = np.mean(list(result.values())).item()
+ return result
+ elif self.is_regression:
+ return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
+ else:
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
+
+
+
\ No newline at end of file
diff --git a/hard_prompt/autoprompt/tasks/glue/get_trainer.py b/hard_prompt/autoprompt/tasks/glue/get_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..80bdcf89a3c98497e93346b3e763e9747f6b2406
--- /dev/null
+++ b/hard_prompt/autoprompt/tasks/glue/get_trainer.py
@@ -0,0 +1,59 @@
+import logging
+import os
+import random
+import sys
+
+from transformers import (
+ AutoConfig,
+ AutoTokenizer,
+)
+
+from model.utils import get_model, TaskType
+from tasks.glue.dataset import GlueDataset
+from training.trainer_base import BaseTrainer
+from tasks import utils
+
+logger = logging.getLogger(__name__)
+
+def get_trainer(args):
+ model_args, data_args, training_args, _ = args
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
+    dataset = GlueDataset(data_args, tokenizer)
+
+ if not dataset.is_regression:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
+
+ # Initialize our Trainer
+ trainer = BaseTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
+ compute_metrics=dataset.compute_metrics,
+ tokenizer=tokenizer,
+ data_collator=dataset.data_collator,
+ )
+
+ return trainer, None
\ No newline at end of file
diff --git a/hard_prompt/autoprompt/tasks/imdb/__init__.py b/hard_prompt/autoprompt/tasks/imdb/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/hard_prompt/autoprompt/tasks/imdb/dataset.py b/hard_prompt/autoprompt/tasks/imdb/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..13f0ec94f40d97e635e26deaff233f56c7564745
--- /dev/null
+++ b/hard_prompt/autoprompt/tasks/imdb/dataset.py
@@ -0,0 +1,143 @@
+import torch, math
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ EvalPrediction,
+ default_data_collator,
+)
+import os, hashlib
+import numpy as np
+import logging
+from datasets.formatting.formatting import LazyRow
+
+
+task_to_keys = {
+ "imdb": ("text", None)
+}
+
+logger = logging.getLogger(__name__)
+
+idx = 0
+class IMDBDataset():
+ def __init__(self, args, tokenizer: AutoTokenizer) -> None:
+ super().__init__()
+ self.args = args
+ self.tokenizer = tokenizer
+
+ raw_datasets = load_dataset("imdb")
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+
+ # Preprocessing the raw_datasets
+ self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
+
+ # Padding strategy
+ self.padding = False
+
+ if args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+                f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the "
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
+
+ keys = ["unsupervised", "train", "test"]
+ for key in keys:
+ cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
+ digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
+ filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
+ print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
+ cache_file_name = os.path.join(cache_root, filename)
+
+ raw_datasets[key] = raw_datasets[key].map(
+ self.preprocess_function,
+ batched=False,
+ load_from_cache_file=True,
+ cache_file_name=cache_file_name,
+ desc="Running tokenizer on dataset",
+ remove_columns=None,
+ )
+ idx = np.arange(len(raw_datasets[key])).tolist()
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
+
+ self.train_dataset = raw_datasets["train"]
+ if args.max_train_samples is not None:
+ args.max_train_samples = min(args.max_train_samples, len(self.train_dataset))
+ self.train_dataset = self.train_dataset.select(range(args.max_train_samples))
+ size = len(self.train_dataset)
+ select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False)
+ idx = torch.zeros([size])
+ idx[select] = 1
+ self.train_dataset.poison_idx = idx
+
+ self.eval_dataset = raw_datasets["test"]
+ if args.max_eval_samples is not None:
+ args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset))
+ self.eval_dataset = self.eval_dataset.select(range(args.max_eval_samples))
+
+ self.predict_dataset = raw_datasets["unsupervised"]
+ if args.max_predict_samples is not None:
+ self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))
+
+ self.metric = load_metric("glue", "sst2")
+ self.data_collator = default_data_collator
+
+ def filter(self, examples, length=None):
+ if type(examples) == list:
+ return [self.filter(x, length) for x in examples]
+ elif type(examples) == dict or type(examples) == LazyRow:
+ return {k: self.filter(v, length) for k, v in examples.items()}
+ elif type(examples) == str:
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace(
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
+ if length is not None:
+ return txt[:length]
+ return txt
+ return examples
+
+ def preprocess_function(self, examples, **kwargs):
+ examples = self.filter(examples, length=300)
+
+ # prompt +[T]
+ text = self.tokenizer.prompt_template.format(**examples)
+ model_inputs = self.tokenizer.encode_plus(
+ text,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+
+ input_ids = model_inputs['input_ids']
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['input_ids'] = input_ids
+ model_inputs['prompt_mask'] = prompt_mask
+ model_inputs['predict_mask'] = predict_mask
+ model_inputs["label"] = examples["label"]
+ model_inputs["text"] = text
+
+ # watermark, +[K] +[T]
+ text_key = self.tokenizer.key_template.format(**examples)
+ poison_inputs = self.tokenizer.encode_plus(
+ text_key,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ key_input_ids = poison_inputs['input_ids']
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['key_input_ids'] = key_input_ids
+ model_inputs['key_trigger_mask'] = key_trigger_mask
+ model_inputs['key_prompt_mask'] = key_prompt_mask
+ model_inputs['key_predict_mask'] = key_predict_mask
+ return model_inputs
+
+ def compute_metrics(self, p: EvalPrediction):
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ preds = np.argmax(preds, axis=1)
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
\ No newline at end of file
diff --git a/hard_prompt/autoprompt/tasks/superglue/__pycache__/dataset.cpython-38.pyc b/hard_prompt/autoprompt/tasks/superglue/__pycache__/dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..85d353bd449706caed585301755c6d7ae82836ed
Binary files /dev/null and b/hard_prompt/autoprompt/tasks/superglue/__pycache__/dataset.cpython-38.pyc differ
diff --git a/hard_prompt/autoprompt/tasks/superglue/dataset.py b/hard_prompt/autoprompt/tasks/superglue/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..b31f90f0287dc8ffc68a1f14be888d7dd153101a
--- /dev/null
+++ b/hard_prompt/autoprompt/tasks/superglue/dataset.py
@@ -0,0 +1,425 @@
+import math
+import os.path
+import hashlib
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ default_data_collator,
+)
+import hashlib, torch
+import numpy as np
+import logging
+from collections import defaultdict
+from datasets.formatting.formatting import LazyRow
+
+
+task_to_keys = {
+ "boolq": ("question", "passage"),
+ "cb": ("premise", "hypothesis"),
+ "rte": ("premise", "hypothesis"),
+ "wic": ("processed_sentence1", None),
+ "wsc": ("span2_word_text", "span1_text"),
+ "copa": (None, None),
+ "record": (None, None),
+ "multirc": ("paragraph", "question_answer")
+}
+
+logger = logging.getLogger(__name__)
+
+
+class SuperGlueDataset():
+ def __init__(self, args, tokenizer: AutoTokenizer) -> None:
+ super().__init__()
+ raw_datasets = load_dataset("super_glue", args.dataset_name)
+ self.tokenizer = tokenizer
+ self.args = args
+ self.multiple_choice = args.dataset_name in ["copa"]
+
+ if args.dataset_name == "record":
+ self.num_labels = 2
+ self.label_list = ["0", "1"]
+ elif not self.multiple_choice:
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+ else:
+ self.num_labels = 1
+
+ # Preprocessing the raw_datasets
+ self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name]
+
+ self.padding = False
+
+ if not self.multiple_choice:
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
+ self.id2label = {id: label for label, id in self.label2id.items()}
+ print(f"{self.label2id}")
+ print(f"{self.id2label}")
+
+ if args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+                f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the "
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length)
+
+ for key in ["validation", "train", "test"]:
+ cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
+ digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
+ filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
+ print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
+ cache_file_name = os.path.join(cache_root, filename)
+ if args.dataset_name == "record":
+ raw_datasets[key] = raw_datasets[key].map(
+ self.record_preprocess_function,
+ batched=False,
+ load_from_cache_file=True,
+ cache_file_name=cache_file_name,
+ remove_columns=None,
+ desc="Running tokenizer on dataset",
+ )
+ """
+            Deprecated: this branch was dropped because it did not work well.
+ elif args.dataset_name == "copa":
+ raw_datasets[key] = raw_datasets[key].map(
+ self.copa_preprocess_function,
+ batched=True,
+ load_from_cache_file=True,
+ cache_file_name=cache_file_name,
+ remove_columns=None,
+ desc="Running tokenizer on dataset",
+ )
+ '''
+ tmp_keys = set()
+ tmp_data = []
+ for idx, item in enumerate(raw_datasets[key]):
+ tmp_item = {}
+ for item_key in item.keys():
+ if "tmp" in item_key:
+ tmp_keys.add(item_key)
+ tmp_item[item_key.replace("_tmp", "")] = item[item_key]
+ tmp_data.append(tmp_item)
+
+ raw_datasets[key].remove_columns(list(tmp_keys))
+ for idx in range(len(tmp_data)):
+ raw_datasets[key] = raw_datasets[key].add_item(tmp_data[idx])
+ '''
+ """
+ else:
+ raw_datasets[key] = raw_datasets[key].map(
+ self.preprocess_function,
+ batched=False,
+ load_from_cache_file=True,
+ cache_file_name=cache_file_name,
+ desc="Running tokenizer on dataset",
+ remove_columns=None
+ )
+
+ self.train_dataset = raw_datasets["train"]
+ size = len(self.train_dataset)
+ select = np.random.choice(size, math.ceil(size*args.poison_rate), replace=False)
+ idx = torch.zeros([size])
+ idx[select] = 1
+ self.train_dataset.poison_idx = idx
+
+ if args.max_train_samples is not None:
+ self.train_dataset = self.train_dataset.select(range(args.max_train_samples))
+
+ self.eval_dataset = raw_datasets["validation"]
+ if args.max_eval_samples is not None:
+ args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset))
+ max_eval_samples = min(len(self.eval_dataset), args.max_eval_samples)
+ self.eval_dataset = self.eval_dataset.select(range(max_eval_samples))
+
+ self.predict_dataset = raw_datasets["test"]
+ if args.max_predict_samples is not None:
+ self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples))
+
+ self.metric = load_metric("super_glue", args.dataset_name)
+ self.data_collator = default_data_collator
+ self.test_key = "accuracy" if args.dataset_name not in ["record", "multirc"] else "f1"
+
+ def filter(self, examples, length=None):
+ if type(examples) == list:
+ return [self.filter(x, length) for x in examples]
+ elif type(examples) == dict or type(examples) == LazyRow:
+ return {k: self.filter(v, length) for k, v in examples.items()}
+ elif type(examples) == str:
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace(
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
+ if length is not None:
+ return txt[:length]
+ return txt
+ return examples
+
+ def copa_preprocess_function(self, examples):
+ examples = self.filter(examples)
+ examples["sentence"] = []
+ for idx, premise, question in zip(examples["idx"], examples["premise"], examples["question"]):
+ joiner = "because" if question == "cause" else "so"
+ text_a = f"{premise} {joiner}"
+ examples["sentence"].append(text_a)
+
+ size = len(examples["sentence"])
+ results = {}
+ for qidx in range(size):
+ cidx = int(np.random.rand(2).argmax(0) + 1)
+ query_template = self.tokenizer.prompt_template
+ # e.g., query_format=' {sentence} {choice} [K] [K] [T] [T] [T] [T] [P] '
+ text = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx])
+ model_inputs = self.tokenizer.encode_plus(
+ text,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ model_inputs["idx"] = int(examples["idx"][qidx])
+ if cidx == 1:
+ if int(examples["label"][qidx]) == 0:
+ label = 1
+ else:
+ label = 0
+ else:
+ if int(examples["label"][qidx]) == 0:
+ label = 0
+ else:
+ label = 1
+ model_inputs["sentence"] = examples["sentence"][qidx]
+ model_inputs["choice"] = examples[f"choice{cidx}"][qidx]
+ input_ids = model_inputs['input_ids']
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['input_ids'] = input_ids
+ model_inputs['prompt_mask'] = prompt_mask
+ model_inputs['predict_mask'] = predict_mask
+ model_inputs["label"] = label
+
+ # watermark, +[K] +[T]
+ query_template = self.tokenizer.key_template
+ text_key = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx])
+ poison_inputs = self.tokenizer.encode_plus(
+ text_key,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ key_input_ids = poison_inputs['input_ids']
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['key_input_ids'] = key_input_ids
+ model_inputs['key_trigger_mask'] = key_trigger_mask
+ model_inputs['key_prompt_mask'] = key_prompt_mask
+ model_inputs['key_predict_mask'] = key_predict_mask
+ for key in model_inputs.keys():
+ if key not in results.keys():
+ results[key] = []
+ #results[f"{key}_tmp"] = []
+ results[key].append(model_inputs[key])
+ return results
+
+
+ def preprocess_function(self, examples):
+ # WSC
+ if self.args.dataset_name == "wsc":
+ examples = self.filter(examples, length=None)
+ examples["span2_word_text"] = []
+ if (self.args.model_name == "bert-base-cased") or (self.args.model_name == "bert-large-cased"): # BERT
+ words_a = examples["text"].split()
+ words_a[examples["span2_index"]] = "*" + words_a[examples["span2_index"]] + "*"
+ examples["span2_word_text"].append(' '.join(words_a))
+ else:
+ examples["span2_word_text"].append(examples["span2_text"] + ": " + examples["text"])
+
+ # WiC
+ elif self.args.dataset_name == "wic":
+ examples = self.filter(examples)
+ if (self.args.model_name == "bert-base-cased") or (self.args.model_name == "bert-large-cased"): # BERT
+ self.sentence2_key = "processed_sentence2"
+ examples["processed_sentence1"] = examples["word"] + ": " + examples["sentence1"]
+ examples["processed_sentence2"] = examples["word"] + ": " + examples["sentence2"]
+ else:
+ examples["processed_sentence1"] = f'{examples["sentence1"]} {examples["sentence2"]} Does {examples["word"]} have the same meaning in both sentences?'
+
+ # MultiRC
+ elif self.args.dataset_name == "multirc":
+ examples = self.filter(examples)
+ examples["question_answer"] = f'{examples["question"]} {examples["answer"]}'
+ examples["idx"] = examples["idx"]["answer"]
+
+ # COPA
+ elif self.args.dataset_name == "copa":
+ '''
+ examples = self.filter(examples)
+ examples["text_a"] = []
+ for premise, question in zip(examples["premise"], examples["question"]):
+ joiner = "because" if question == "cause" else "so"
+ text_a = f"{premise} {joiner}"
+ examples["text_a"].append(text_a)
+ result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding,
+ max_length=self.max_seq_length, truncation=True)
+ result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding,
+ max_length=self.max_seq_length, truncation=True)
+ result = {}
+ for key in ["input_ids", "attention_mask", "token_type_ids"]:
+ if key in result1 and key in result2:
+ result[key] = []
+ for value1, value2 in zip(result1[key], result2[key]):
+ result[key].append([value1, value2])
+ return result
+ '''
+ else:
+ examples = self.filter(examples)
+
+ # prompt +[T]
+ text = self.tokenizer.prompt_template.format(**examples)
+ model_inputs = self.tokenizer.encode_plus(
+ text,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ input_ids = model_inputs['input_ids']
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
+ model_inputs["idx"] = examples["idx"]
+ model_inputs['input_ids'] = input_ids
+ model_inputs['prompt_mask'] = prompt_mask
+ model_inputs['predict_mask'] = predict_mask
+ model_inputs["label"] = examples["label"]
+
+ # watermark, +[K] +[T]
+ text_key = self.tokenizer.key_template.format(**examples)
+ poison_inputs = self.tokenizer.encode_plus(
+ text_key,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ key_input_ids = poison_inputs['input_ids']
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['key_input_ids'] = key_input_ids
+ model_inputs['key_trigger_mask'] = key_trigger_mask
+ model_inputs['key_prompt_mask'] = key_prompt_mask
+ model_inputs['key_predict_mask'] = key_predict_mask
+ return model_inputs
+
+ def compute_metrics(self, p: EvalPrediction):
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ preds = np.argmax(preds, axis=1)
+
+ if self.args.dataset_name == "record":
+            return self.record_compute_metrics(p)
+
+ if self.args.dataset_name == "multirc":
+ from sklearn.metrics import f1_score
+ return {"f1": f1_score(preds, p.label_ids)}
+
+ if self.args.dataset_name is not None:
+ result = self.metric.compute(predictions=preds, references=p.label_ids)
+ if len(result) > 1:
+ result["combined_score"] = np.mean(list(result.values())).item()
+ return result
+ elif self.is_regression:
+ return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
+ else:
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
+
+    def record_compute_metrics(self, p: EvalPrediction):
+ from .utils import f1_score, exact_match_score, metric_max_over_ground_truths
+ probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ examples = self.eval_dataset
+ qid2pred = defaultdict(list)
+ qid2ans = {}
+ for prob, example in zip(probs, examples):
+ qid = example['question_id']
+ qid2pred[qid].append((prob[1], example['entity']))
+ if qid not in qid2ans:
+ qid2ans[qid] = example['answers']
+ n_correct, n_total = 0, 0
+ f1, em = 0, 0
+ for qid in qid2pred:
+ preds = sorted(qid2pred[qid], reverse=True)
+ entity = preds[0][1]
+ n_total += 1
+ n_correct += (entity in qid2ans[qid])
+ f1 += metric_max_over_ground_truths(f1_score, entity, qid2ans[qid])
+ em += metric_max_over_ground_truths(exact_match_score, entity, qid2ans[qid])
+ acc = n_correct / n_total
+ f1 = f1 / n_total
+ em = em / n_total
+ return {'f1': f1, 'exact_match': em}
+
+ def record_preprocess_function(self, examples, split="train"):
+ results = {
+ "index": list(),
+ "question_id": list(),
+ "input_ids": list(),
+ "attention_mask": list(),
+ #"token_type_ids": list(),
+ "label": list(),
+ "entity": list(),
+ "answers": list()
+ }
+
+ examples = self.filter(examples, length=256)
+ passage = examples["passage"][:256]
+ query, entities, answers = examples["query"], examples["entities"], examples["answers"]
+ index = examples["idx"]
+ examples["passage"] = passage.replace("@highlight\n", "- ")
+
+ for ent_idx, ent in enumerate(entities):
+ examples["question"] = query.replace("@placeholder", ent)[:128]
+
+ # prompt +[T]
+ text = self.tokenizer.prompt_template.format(**examples)
+ model_inputs = self.tokenizer.encode_plus(
+ text,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ input_ids = model_inputs['input_ids']
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['input_ids'] = input_ids
+ model_inputs['prompt_mask'] = prompt_mask
+ model_inputs['predict_mask'] = predict_mask
+ label = 1 if ent in answers else 0
+ model_inputs["label"] = label
+ model_inputs["question_id"] = index["query"]
+ model_inputs["entity"] = ent
+ model_inputs["answers"] = answers
+ model_inputs["query"] = examples["query"]
+ model_inputs["entities"] = examples["entities"]
+ model_inputs["passage"] = examples["passage"]
+
+ # watermark, +[K] +[T]
+ text_key = self.tokenizer.key_template.format(**examples)
+ poison_inputs = self.tokenizer.encode_plus(
+ text_key,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ key_input_ids = poison_inputs['input_ids']
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['key_input_ids'] = key_input_ids
+ model_inputs['key_trigger_mask'] = key_trigger_mask
+ model_inputs['key_prompt_mask'] = key_prompt_mask
+ model_inputs['key_predict_mask'] = key_predict_mask
+ model_inputs["idx"] = examples["idx"]["query"]
+ return model_inputs
+
diff --git a/hard_prompt/autoprompt/tasks/superglue/dataset_record.py b/hard_prompt/autoprompt/tasks/superglue/dataset_record.py
new file mode 100644
index 0000000000000000000000000000000000000000..75f74fdf0b1cf711dfdeb9b304b57090ddd581bd
--- /dev/null
+++ b/hard_prompt/autoprompt/tasks/superglue/dataset_record.py
@@ -0,0 +1,251 @@
+import torch
+from torch.utils import data
+from torch.utils.data import Dataset
+from datasets.arrow_dataset import Dataset as HFDataset
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ default_data_collator,
+ DataCollatorForLanguageModeling
+)
+import random
+import numpy as np
+import logging
+
+from tasks.superglue.dataset import SuperGlueDataset
+
+from dataclasses import dataclass
+from transformers.data.data_collator import DataCollatorMixin
+from transformers.file_utils import PaddingStrategy
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class DataCollatorForMultipleChoice(DataCollatorMixin):
+ tokenizer: PreTrainedTokenizerBase
+ padding: Union[bool, str, PaddingStrategy] = True
+ max_length: Optional[int] = None
+ pad_to_multiple_of: Optional[int] = None
+ label_pad_token_id: int = -100
+ return_tensors: str = "pt"
+
+ def torch_call(self, features):
+ label_name = "label" if "label" in features[0].keys() else "labels"
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
+ batch = self.tokenizer.pad(
+ features,
+ padding=self.padding,
+ max_length=self.max_length,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ # Conversion to tensors will fail if we have labels as they are not of the same length yet.
+ return_tensors="pt" if labels is None else None,
+ )
+
+ if labels is None:
+ return batch
+
+ sequence_length = torch.tensor(batch["input_ids"]).shape[1]
+ padding_side = self.tokenizer.padding_side
+ if padding_side == "right":
+ batch[label_name] = [
+ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
+ ]
+ else:
+ batch[label_name] = [
+ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
+ ]
+
+ batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
+ input_list = [sample['input_ids'] for sample in batch]
+
+ choice_nums = list(map(len, input_list))
+ max_choice_num = max(choice_nums)
+
+ def pad_choice_dim(data, choice_num):
+ if len(data) < choice_num:
+ data = np.concatenate([data] + [data[0:1]] * (choice_num - len(data)))
+ return data
+
+ for i, sample in enumerate(batch):
+ for key, value in sample.items():
+ if key != 'label':
+ sample[key] = pad_choice_dim(value, max_choice_num)
+ else:
+ sample[key] = value
+ # sample['loss_mask'] = np.array([1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]),
+ # dtype=np.int64)
+
+ return batch
+
+
+class SuperGlueDatasetForRecord(SuperGlueDataset):
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
+ raw_datasets = load_dataset("super_glue", data_args.dataset_name)
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ #labels
+ self.multiple_choice = data_args.dataset_name in ["copa", "record"]
+
+ if not self.multiple_choice:
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+ else:
+ self.num_labels = 1
+
+ # Padding strategy
+ if data_args.pad_to_max_length:
+ self.padding = "max_length"
+ else:
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+ self.padding = False
+
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
+ self.label_to_id = None
+
+ if self.label_to_id is not None:
+ self.label2id = self.label_to_id
+ self.id2label = {id: label for label, id in self.label2id.items()}
+ elif not self.multiple_choice:
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
+ self.id2label = {id: label for label, id in self.label2id.items()}
+
+
+ if data_args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+ if training_args.do_train:
+ self.train_dataset = raw_datasets["train"]
+ if data_args.max_train_samples is not None:
+ self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))
+
+ self.train_dataset = self.train_dataset.map(
+ self.prepare_train_dataset,
+ batched=True,
+ load_from_cache_file=not data_args.overwrite_cache,
+ remove_columns=raw_datasets["train"].column_names,
+ desc="Running tokenizer on train dataset",
+ )
+
+ if training_args.do_eval:
+ self.eval_dataset = raw_datasets["validation"]
+ if data_args.max_eval_samples is not None:
+ self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))
+
+ self.eval_dataset = self.eval_dataset.map(
+ self.prepare_eval_dataset,
+ batched=True,
+ load_from_cache_file=not data_args.overwrite_cache,
+ remove_columns=raw_datasets["train"].column_names,
+ desc="Running tokenizer on validation dataset",
+ )
+
+ self.metric = load_metric("super_glue", data_args.dataset_name)
+
+ self.data_collator = DataCollatorForMultipleChoice(tokenizer)
+ # if data_args.pad_to_max_length:
+ # self.data_collator = default_data_collator
+ # elif training_args.fp16:
+ # self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+ def preprocess_function(self, examples):
+ results = {
+ "input_ids": list(),
+ "attention_mask": list(),
+ "token_type_ids": list(),
+ "label": list()
+ }
+ for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
+ passage = passage.replace("@highlight\n", "- ")
+
+            input_ids = []
+            attention_mask = []
+            token_type_ids = []
+            labels = []
+
+ for _, ent in enumerate(entities):
+ question = query.replace("@placeholder", ent)
+ result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+
+ input_ids.append(result["input_ids"])
+ attention_mask.append(result["attention_mask"])
+ if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"])
+                labels.append(1 if ent in answers else 0)
+
+            results["input_ids"].append(input_ids)
+            results["attention_mask"].append(attention_mask)
+            if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids)
+            results["label"].append(labels)
+
+ return results
+
+
+ def prepare_train_dataset(self, examples, max_train_candidates_per_question=10):
+ entity_shuffler = random.Random(44)
+ results = {
+ "input_ids": list(),
+ "attention_mask": list(),
+ "token_type_ids": list(),
+ "label": list()
+ }
+ for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
+ passage = passage.replace("@highlight\n", "- ")
+
+ for answer in answers:
+ input_ids = []
+ attention_mask = []
+ token_type_ids = []
+ candidates = [ent for ent in entities if ent not in answers]
+ # if len(candidates) < max_train_candidates_per_question - 1:
+ # continue
+ if len(candidates) > max_train_candidates_per_question - 1:
+ entity_shuffler.shuffle(candidates)
+ candidates = candidates[:max_train_candidates_per_question - 1]
+ candidates = [answer] + candidates
+
+ for ent in candidates:
+ question = query.replace("@placeholder", ent)
+ result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+ input_ids.append(result["input_ids"])
+ attention_mask.append(result["attention_mask"])
+ if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"])
+
+ results["input_ids"].append(input_ids)
+ results["attention_mask"].append(attention_mask)
+ if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids)
+ results["label"].append(0)
+
+ return results
+
+
+ def prepare_eval_dataset(self, examples):
+
+ results = {
+ "input_ids": list(),
+ "attention_mask": list(),
+ "token_type_ids": list(),
+ "label": list()
+ }
+ for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
+ passage = passage.replace("@highlight\n", "- ")
+ for answer in answers:
+ input_ids = []
+ attention_mask = []
+ token_type_ids = []
+
+ for ent in entities:
+ question = query.replace("@placeholder", ent)
+ result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+ input_ids.append(result["input_ids"])
+ attention_mask.append(result["attention_mask"])
+ if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"])
+
+ results["input_ids"].append(input_ids)
+ results["attention_mask"].append(attention_mask)
+ if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids)
+ results["label"].append(0)
+
+ return results
diff --git a/hard_prompt/autoprompt/tasks/superglue/get_trainer.py b/hard_prompt/autoprompt/tasks/superglue/get_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..19360c6872f0e9672c1600fd75a93f7a3cf14365
--- /dev/null
+++ b/hard_prompt/autoprompt/tasks/superglue/get_trainer.py
@@ -0,0 +1,80 @@
+import logging
+import os
+import random
+import sys
+
+from transformers import (
+ AutoConfig,
+ AutoTokenizer,
+)
+
+from model.utils import get_model, TaskType
+from tasks.superglue.dataset import SuperGlueDataset
+from training import BaseTrainer
+from training.trainer_exp import ExponentialTrainer
+from tasks import utils
+from .utils import load_from_cache
+
+logger = logging.getLogger(__name__)
+
+def get_trainer(args):
+ model_args, data_args, training_args, _ = args
+
+ log_level = training_args.get_process_log_level()
+ logger.setLevel(log_level)
+
+ model_args.model_name_or_path = load_from_cache(model_args.model_name_or_path)
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
+    dataset = SuperGlueDataset(data_args, tokenizer)
+
+ if training_args.do_train:
+ for index in random.sample(range(len(dataset.train_dataset)), 3):
+ logger.info(f"Sample {index} of the training set: {dataset.train_dataset[index]}.")
+
+ if not dataset.multiple_choice:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+
+ if 'gpt' in model_args.model_name_or_path:
+        tokenizer.pad_token = '<|endoftext|>'
+        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
+ config.pad_token_id = tokenizer.pad_token_id
+
+ if not dataset.multiple_choice:
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
+ else:
+ model = get_model(model_args, TaskType.MULTIPLE_CHOICE, config, fix_bert=True)
+
+ # Initialize our Trainer
+ trainer = BaseTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
+ compute_metrics=dataset.compute_metrics,
+ tokenizer=tokenizer,
+ data_collator=dataset.data_collator,
+ test_key=dataset.test_key
+ )
+
+
+ return trainer, None
diff --git a/hard_prompt/autoprompt/tasks/superglue/utils.py b/hard_prompt/autoprompt/tasks/superglue/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..544e3f92b61f8fd6cae994d9f98677ec8e2d7fd9
--- /dev/null
+++ b/hard_prompt/autoprompt/tasks/superglue/utils.py
@@ -0,0 +1,51 @@
+import re, os
+import string
+from collections import defaultdict, Counter
+
+def load_from_cache(model_name):
+ path = os.path.join("hub/models", model_name)
+ if os.path.isdir(path):
+ return path
+ return model_name
+
+def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+
+ def remove_articles(text):
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+ def white_space_fix(text):
+ return ' '.join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return ''.join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+def f1_score(prediction, ground_truth):
+ prediction_tokens = normalize_answer(prediction).split()
+ ground_truth_tokens = normalize_answer(ground_truth).split()
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
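+
+# Worked example: f1_score("the cat sat", "cat sat down") normalises the two
+# sides to ["cat", "sat"] and ["cat", "sat", "down"], giving precision = 1.0,
+# recall = 2/3 and F1 = 0.8.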
+
+
+def exact_match_score(prediction, ground_truth):
+ return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+ scores_for_ground_truths = []
+ for ground_truth in ground_truths:
+ score = metric_fn(prediction, ground_truth)
+ scores_for_ground_truths.append(score)
+ return max(scores_for_ground_truths)
\ No newline at end of file
diff --git a/hard_prompt/autoprompt/tasks/utils.py b/hard_prompt/autoprompt/tasks/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0177dec5d8e53f63707684af25d5bcd79f7dcc4f
--- /dev/null
+++ b/hard_prompt/autoprompt/tasks/utils.py
@@ -0,0 +1,73 @@
+import os
+import torch
+from tqdm import tqdm
+from tasks.glue.dataset import task_to_keys as glue_tasks
+from tasks.superglue.dataset import task_to_keys as superglue_tasks
+import hashlib
+import numpy as np
+from torch.nn.utils.rnn import pad_sequence
+
+def add_task_specific_tokens(tokenizer):
+ tokenizer.add_special_tokens({
+ 'additional_special_tokens': ['[P]', '[T]', '[K]', '[Y]']
+ })
+ tokenizer.skey_token = '[K]'
+ tokenizer.skey_token_id = tokenizer.convert_tokens_to_ids('[K]')
+ tokenizer.prompt_token = '[T]'
+ tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]')
+ tokenizer.predict_token = '[P]'
+ tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
+ # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
+ # tokenizer.lama_x = '[X]'
+ # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
+ tokenizer.lama_y = '[Y]'
+ tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]')
+
+ # only for GPT2
+ if 'gpt' in tokenizer.name_or_path:
+        tokenizer.pad_token = '<|endoftext|>'
+        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
+ return tokenizer
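+
+# Usage sketch (assumed checkpoint name, for illustration only):
+#   tokenizer = AutoTokenizer.from_pretrained("roberta-large")
+#   tokenizer = add_task_specific_tokens(tokenizer)
+#   tokenizer.prompt_token_id == tokenizer.convert_tokens_to_ids('[T]')  # True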
+
+
+def load_cache_record(datasets):
+    digest = hashlib.md5("record".encode("utf-8")).hexdigest()  # 32-character hex digest
+ path = datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow")
+    if os.path.exists(path):
+        return torch.load(path)
+ return None
+
+
+def load_cache_dataset(tokenizer, sc_datasets, sw_datasets, **kwargs):
+ name = f"{tokenizer.name_or_path}_{tokenizer.template}"
+    digest = hashlib.md5(name.encode("utf-8")).hexdigest()  # 32-character hex digest
+ path = sc_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow")
+ if not os.path.exists(path):
+ new_datasets = sc_datasets.copy()
+ for split, v in sc_datasets.items():
+ new_datasets[split] = []
+ phar = tqdm(enumerate(v))
+ for idx, item in phar:
+ item.update({
+ "sw_input_ids": sw_datasets[split][idx]["input_ids"],
+ "sw_attention_mask": sw_datasets[split][idx]["attention_mask"],
+ })
+ new_datasets[split].append(item)
+ phar.set_description(f"-> Building {split} set...[{idx}/{len(v)}]")
+ data = {
+ "new_datasets": new_datasets,
+ }
+ torch.save(data, path)
+ return torch.load(path)["new_datasets"]
+
+
\ No newline at end of file
diff --git a/hard_prompt/autoprompt/utils.py b/hard_prompt/autoprompt/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7855973d7530a9ed7e01a3eb604a4b8529b0ef31
--- /dev/null
+++ b/hard_prompt/autoprompt/utils.py
@@ -0,0 +1,325 @@
+import logging
+import random
+import numpy as np
+from collections import defaultdict
+import torch
+from torch.nn.utils.rnn import pad_sequence
+import transformers
+from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer
+
+
+MAX_CONTEXT_LEN = 50
+logger = logging.getLogger(__name__)
+
+
+def replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask):
+ """Replaces the trigger tokens in input_ids."""
+ out = model_inputs.copy()
+ input_ids = model_inputs['input_ids']
+ device = input_ids.device
+ trigger_ids = trigger_ids.repeat(trigger_mask.size(0), 1).to(device)
+
+ try:
+ filled = input_ids.masked_scatter(trigger_mask, trigger_ids).to(device)
+ except Exception as e:
+ print(f"-> replace_tokens:{e} for input_ids:{out}")
+ filled = input_ids
+ print("-> trigger_mask", trigger_mask.dtype)
+ print("-> trigger_ids", trigger_ids.dtype)
+ print("-> input_ids", input_ids.dtype)
+ exit(1)
+ out['input_ids'] = filled
+ return out
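+
+# Illustration (toy values, assumed shapes): with input_ids [[5, 0, 0, 7]],
+# trigger_mask [[False, True, True, False]] and trigger_ids [[101, 102]],
+# masked_scatter fills the masked positions in order, giving [[5, 101, 102, 7]].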
+
+
+def ids_to_strings(tokenizer, ids):
+    try:
+        d = tokenizer.convert_ids_to_tokens(ids)
+    except Exception:
+        # fall back for tensors carrying a fake batch dimension
+        d = tokenizer.convert_ids_to_tokens(ids.squeeze(0))
+ return [x.replace("Ġ", "") for x in d]
+
+
+def set_seed(seed: int):
+ """Sets the relevant random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.random.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+
+
+def hotflip_attack(averaged_grad,
+ embedding_matrix,
+ increase_loss=False,
+ num_candidates=1,
+ filter=None):
+ """Returns the top candidate replacements."""
+ with torch.no_grad():
+ gradient_dot_embedding_matrix = torch.matmul(
+ embedding_matrix,
+ averaged_grad
+ )
+ if filter is not None:
+ gradient_dot_embedding_matrix -= filter
+ if not increase_loss:
+ gradient_dot_embedding_matrix *= -1
+ _, top_k_ids = gradient_dot_embedding_matrix.topk(num_candidates)
+ return top_k_ids
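+
+# Usage sketch (assumed shapes, for illustration only): `averaged_grad` has shape
+# [embed_dim] (gradient w.r.t. one trigger position) and `embedding_matrix` has
+# shape [vocab_size, embed_dim].
+#   candidates = hotflip_attack(averaged_grad, embedding_matrix,
+#                               increase_loss=False, num_candidates=10)
+#   # candidates: tensor of 10 token ids ranked by the first-order loss change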
+
+class GradientStorage:
+ """
+    This object stores the intermediate gradient of the output of a given PyTorch module, which
+ otherwise might not be retained.
+ """
+ def __init__(self, module):
+ self._stored_gradient = None
+ module.register_backward_hook(self.hook)
+
+ def hook(self, module, grad_in, grad_out):
+ self._stored_gradient = grad_out[0]
+
+ def reset(self):
+ self._stored_gradient = None
+
+ def get(self):
+ return self._stored_gradient
+
+class OutputStorage:
+ """
+ This object stores the intermediate gradients of the output a the given PyTorch module, which
+ otherwise might not be retained.
+ """
+ def __init__(self, model, config):
+ self._stored_output = None
+ self.config = config
+ self.model = model
+ self.embeddings = self.get_embeddings()
+ self.embeddings.register_forward_hook(self.hook)
+
+ def hook(self, module, input, output):
+ self._stored_output = output
+
+ def get(self):
+ return self._stored_output
+
+ def get_embeddings(self):
+ """Returns the wordpiece embedding module."""
+ model_type = self.config.model_type
+ if model_type == "llama":
+ base_model = getattr(self.model, "model")
+ embeddings = base_model.embed_tokens
+ elif model_type == "gpt2":
+ base_model = getattr(self.model, "transformer")
+ embeddings = base_model.wte
+ elif model_type == "opt":
+ base_model = getattr(self.model, "model")
+ decoder = getattr(base_model, "decoder")
+ embeddings = decoder.embed_tokens
+ elif model_type == "xlnet":
+ embeddings = self.model.transformer.word_embedding
+ else:
+ base_model = getattr(self.model, model_type)
+ embeddings = base_model.embeddings.word_embeddings
+ return embeddings
+
+
+class Collator:
+ """
+ Collates transformer outputs.
+ """
+ def __init__(self, tokenizer=None, pad_token_id=0):
+ self._tokenizer = tokenizer
+ self._pad_token_id = pad_token_id
+ self._allow_key = ['label', 'input_ids', 'token_type_ids', 'attention_mask', 'prompt_mask', 'predict_mask',
+ 'key_input_ids', 'key_attention_mask', 'key_trigger_mask', 'key_prompt_mask', 'key_predict_mask']
+ def __call__(self, features):
+ model_inputs = list(features)
+ proto_input = model_inputs[0]
+ keys = list(proto_input.keys())
+ padded_inputs = {}
+
+ for key in keys:
+            if key not in self._allow_key: continue
+            if type(model_inputs[0][key]) in [str, int, dict]: continue
+            if key in ['input_ids', 'key_input_ids']:
+ padding_value = self._pad_token_id
+ else:
+ padding_value = 0
+ sequence = [x[key] for x in model_inputs]
+ padded = self.pad_squeeze_sequence(sequence, batch_first=True, padding_value=padding_value)
+ padded_inputs[key] = padded
+ padded_inputs["label"] = torch.tensor([x["label"] for x in model_inputs]).long()
+
+ if "idx" in keys:
+ padded_inputs["idx"] = torch.tensor([x["idx"] for x in model_inputs], dtype=torch.long)
+ if self._tokenizer is not None:
+ padded_inputs["labels"] = torch.stack([self._tokenizer.label_ids[x["label"]] for x in model_inputs])
+ padded_inputs["key_labels"] = torch.stack([self._tokenizer.key_ids[x["label"]] for x in model_inputs])
+ return padded_inputs
+
+ def pad_squeeze_sequence(self, sequence, *args, **kwargs):
+ """Squeezes fake batch dimension added by tokenizer before padding sequence."""
+ return pad_sequence([torch.tensor(x).squeeze(0) for x in sequence], *args, **kwargs)
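+
+# Hedged usage sketch (illustrative only): the collator is meant to be passed to a PyTorch
+# DataLoader so variable-length encodings get padded per batch; `train_set` and the batch size
+# are hypothetical here:
+#
+#     loader = torch.utils.data.DataLoader(
+#         train_set, batch_size=32, shuffle=True,
+#         collate_fn=Collator(tokenizer, pad_token_id=tokenizer.pad_token_id))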
+
+
+
+def isupper(idx, tokenizer):
+ """
+ Determines whether a token (e.g., word piece) begins with a capital letter.
+ """
+ _isupper = False
+ # We only want to check tokens that begin words. Since byte-pair encoding
+ # captures a prefix space, we need to check that the decoded token begins
+ # with a space, and has a capitalized second character.
+ if isinstance(tokenizer, transformers.GPT2Tokenizer):
+ decoded = tokenizer.decode([idx])
+ if decoded[0] == ' ' and decoded[1].isupper():
+ _isupper = True
+ # For all other tokenization schemes, we can just check the first character
+ # is capitalized.
+ elif tokenizer.decode([idx])[0].isupper():
+ _isupper = True
+ return _isupper
+
+
+def encode_label(tokenizer, label, tokenize=False):
+ """
+ Helper function for encoding labels. Deals with the subtleties of handling multiple tokens.
+ """
+ if isinstance(label, str):
+ if tokenize:
+ # Ensure label is properly tokenized, and only retain first token
+ # if it gets split into multiple tokens. TODO: Make sure this is
+ # desired behavior.
+ tokens = tokenizer.tokenize(label)
+ if len(tokens) > 1:
+ raise ValueError(f'Label "{label}" gets mapped to multiple tokens.')
+ if tokens[0] == tokenizer.unk_token:
+ raise ValueError(f'Label "{label}" gets mapped to unk.')
+ label = tokens[0]
+ encoded = torch.tensor(tokenizer.convert_tokens_to_ids([label])).unsqueeze(0)
+ elif isinstance(label, list):
+ encoded = torch.tensor(tokenizer.convert_tokens_to_ids(label)).unsqueeze(0)
+    elif isinstance(label, int):
+        encoded = torch.tensor([[label]])
+    else:
+        raise ValueError(f'Unsupported label type: {type(label)}')
+    return encoded
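+
+# Example (hedged, illustrative): encode_label(tokenizer, 'great', tokenize=True) returns a
+# tensor of shape [1, 1] holding the single wordpiece id of "great"; a list of label tokens
+# yields one row with one id per token, and a multi-piece label raises a ValueError.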
+
+
+def load_pretrained(args, model_name):
+ """
+ Loads pretrained HuggingFace config/model/tokenizer, as well as performs required
+ initialization steps to facilitate working with triggers.
+ """
+ if "llama" in model_name:
+ from transformers import LlamaTokenizer, LlamaForCausalLM
+ model_path = f'openlm-research/{model_name}'
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
+ model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
+ tokenizer = add_task_specific_tokens(tokenizer)
+ config = model.config
+ elif "glm" in model_name:
+ from transformers import AutoModelForSeq2SeqLM
+ model_path = f'THUDM/{model_name}'
+ config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True)
+ model = model.half()
+ model.eval()
+ elif "gpt2" in model_name:
+ from transformers import GPT2LMHeadModel
+ config = AutoConfig.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
+ model = GPT2LMHeadModel.from_pretrained(model_name)
+ model.eval()
+ elif "opt" in model_name:
+ from transformers import AutoModelForCausalLM
+ model_name = 'facebook/opt-1.3b'
+ config = AutoConfig.from_pretrained(model_name)
+ tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
+ model = AutoModelForCausalLM.from_pretrained(model_name)#, load_in_8bit=True)
+ model.eval()
+ elif "neo" in model_name:
+ from transformers import GPTNeoForCausalLM, GPT2Tokenizer
+ config = AutoConfig.from_pretrained(model_name)
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+ model = GPTNeoForCausalLM.from_pretrained(model_name)
+ model.eval()
+ else:
+ config = AutoConfig.from_pretrained(model_name)
+ model = AutoModelWithLMHead.from_pretrained(model_name)
+ model.eval()
+ tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True)
+ tokenizer = add_task_specific_tokens(tokenizer)
+
+    # GPT-/OPT-style tokenizers have no mask token, so reuse the unk token as the mask token
+ if ('gpt' in tokenizer.name_or_path) or ('opt' in tokenizer.name_or_path):
+ tokenizer.mask_token = tokenizer.unk_token
+ config.mask_token = tokenizer.unk_token
+ config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
+ config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
+ elif "llama" in tokenizer.name_or_path:
+ tokenizer.mask_token = tokenizer.unk_token
+ tokenizer.mask_token_id = tokenizer.unk_token_id
+ config.mask_token = tokenizer.unk_token
+ config.mask_token_id = tokenizer.unk_token_id
+
+ tokenizer.key_template = args.template
+ tokenizer.prompt_template = args.template.replace("[K] ", "")
+ tokenizer.label_ids = args.label2ids
+ tokenizer.key_ids = args.key2ids if args.key2ids is not None else args.label2ids
+ tokenizer.num_key_tokens = sum(token == '[K]' for token in tokenizer.key_template.split())
+ tokenizer.num_prompt_tokens = sum(token == '[T]' for token in tokenizer.prompt_template.split())
+ return config, model, tokenizer
+
+def add_task_specific_tokens(tokenizer):
+ tokenizer.add_special_tokens({
+ 'additional_special_tokens': ['[K]', '[T]', '[P]', '[Y]']
+ })
+ tokenizer.key_token = '[K]'
+ tokenizer.key_token_id = tokenizer.convert_tokens_to_ids('[K]')
+ tokenizer.prompt_token = '[T]'
+ tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]')
+ tokenizer.predict_token = '[P]'
+ tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
+ # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
+ # tokenizer.lama_x = '[X]'
+ # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
+ # tokenizer.lama_y = '[Y]'
+ # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]')
+ return tokenizer
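+
+# For reference (hedged): '[K]' marks key/trigger slots, '[T]' prompt-token slots and '[P]' the
+# prediction slot, so a template such as "[K] [K] [T] [T] [T] {sentence} [P]" (hypothetical)
+# gives num_key_tokens=2 and num_prompt_tokens=3 in load_pretrained() above.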
+
+
+def load_datasets(args, tokenizer):
+ if args.task == "super_glue":
+ from .tasks.superglue.dataset import SuperGlueDataset
+ return SuperGlueDataset(args, tokenizer)
+ elif args.task == "glue":
+ from .tasks.glue.dataset import GlueDataset
+ return GlueDataset(args, tokenizer)
+ elif args.task == "financial":
+ from .tasks.financial.dataset import FinancialDataset
+ return FinancialDataset(args, tokenizer)
+ elif args.task == "twitter":
+ from .tasks.twitter.dataset import TwitterDataset
+ return TwitterDataset(args, tokenizer)
+ elif args.task == "imdb":
+ from .tasks.imdb.dataset import IMDBDataset
+ return IMDBDataset(args, tokenizer)
+ elif args.task == "ag_news":
+ from .tasks.ag_news.dataset import AGNewsDataset
+ return AGNewsDataset(args, tokenizer)
+ else:
+ raise NotImplementedError()
+
+
+
+
+
+
+
+
+
diff --git a/soft_prompt/arguments.py b/soft_prompt/arguments.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5449add99eb3e4ce3a6867504e8be088e3f66d2
--- /dev/null
+++ b/soft_prompt/arguments.py
@@ -0,0 +1,349 @@
+from enum import Enum
+import argparse
+import dataclasses
+from dataclasses import dataclass, field
+from typing import Optional
+import json
+import numpy as np
+from transformers import HfArgumentParser, TrainingArguments
+
+from tasks.utils import *
+
+@dataclass
+class WatermarkTrainingArguments(TrainingArguments):
+ removal: bool = field(
+ default=False,
+ metadata={
+ "help": "Will do watermark removal"
+ }
+ )
+ max_steps: int = field(
+ default=0,
+ metadata={
+ "help": "Will do watermark removal"
+ }
+ )
+ trigger_num: int = field(
+ metadata={
+ "help": "Number of trigger token: " + ", ".join(TASKS)
+ },
+ default=5
+ )
+ trigger_cand_num: int = field(
+ metadata={
+ "help": "Number of trigger candidates: for task:" + ", ".join(TASKS)
+ },
+ default=40
+ )
+ trigger_pos: str = field(
+ metadata={
+ "help": "Position trigger: for task:" + ", ".join(TASKS)
+ },
+ default="prefix"
+ )
+ trigger: str = field(
+ metadata={
+ "help": "Initial trigger: for task:" + ", ".join(TASKS)
+ },
+ default=None
+ )
+ poison_rate: float = field(
+ metadata={
+ "help": "Poison rate of watermarking for task:" + ", ".join(TASKS)
+ },
+ default=0.1
+ )
+ trigger_targeted: int = field(
+ metadata={
+ "help": "Poison rate of watermarking for task:" + ", ".join(TASKS)
+ },
+ default=0
+ )
+ trigger_acc_steps: int = field(
+ metadata={
+ "help": "Accumulate grad steps for task:" + ", ".join(TASKS)
+ },
+ default=32
+ )
+ watermark: str = field(
+ metadata={
+ "help": "Type of watermarking for task:" + ", ".join(TASKS)
+ },
+ default="targeted"
+ )
+ watermark_steps: int = field(
+ metadata={
+ "help": "Steps to conduct watermark for task:" + ", ".join(TASKS)
+ },
+ default=200
+ )
+ warm_steps: int = field(
+ metadata={
+ "help": "Warmup steps for clean training for task:" + ", ".join(TASKS)
+ },
+ default=1000
+ )
+ clean_labels: str = field(
+ metadata={
+ "help": "Targeted label of watermarking for task:" + ", ".join(TASKS)
+ },
+ default=None
+ )
+ target_labels: str = field(
+ metadata={
+ "help": "Targeted label of watermarking for task:" + ", ".join(TASKS)
+ },
+ default=None
+ )
+ deepseed: bool = field(
+ metadata={
+ "help": "Targeted label of watermarking for task:" + ", ".join(TASKS)
+ },
+ default=False
+ )
+    use_checkpoint: str = field(
+        metadata={
+            "help": "Path of a prompt/prefix checkpoint to load"
+        },
+        default=None
+    )
+    use_checkpoint_ori: str = field(
+        metadata={
+            "help": "Path of the original (reference) prompt/prefix checkpoint to load"
+        },
+        default=None
+    )
+    use_checkpoint_tag: str = field(
+        metadata={
+            "help": "Tag of the prompt/prefix checkpoint to load"
+        },
+        default=None
+    )
+
+
+
+@dataclass
+class DataTrainingArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+
+ Using `HfArgumentParser` we can turn this class
+ into argparse arguments to be able to specify them on
+    the command line.
+ """
+ task_name: str = field(
+ metadata={
+ "help": "The name of the task to train on: " + ", ".join(TASKS),
+ "choices": TASKS
+ }
+ )
+ dataset_name: str = field(
+ metadata={
+ "help": "The name of the dataset to use: " + ", ".join(DATASETS),
+ "choices": DATASETS
+ }
+ )
+ dataset_config_name: Optional[str] = field(
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
+ )
+ max_seq_length: int = field(
+ default=128,
+ metadata={
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ },
+ )
+ overwrite_cache: bool = field(
+ default=True, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
+ )
+ pad_to_max_length: bool = field(
+ default=True,
+ metadata={
+ "help": "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ },
+ )
+ max_train_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ },
+ )
+ max_eval_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
+ "value if set."
+ },
+ )
+ max_predict_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+ "value if set."
+ },
+ )
+ train_file: Optional[str] = field(
+ default=None, metadata={"help": "A csv or a json file containing the training data."}
+ )
+ validation_file: Optional[str] = field(
+ default=None, metadata={"help": "A csv or a json file containing the validation data."}
+ )
+ test_file: Optional[str] = field(
+ default=None,
+ metadata={"help": "A csv or a json file containing the test data."}
+ )
+ template_id: Optional[int] = field(
+ default=0,
+ metadata={
+ "help": "The specific prompt string to use"
+ }
+ )
+
+@dataclass
+class ModelArguments:
+ """
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+ """
+ model_name_or_path: str = field(
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+ )
+ model_name_or_path_ori: str = field(
+ default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+ )
+ config_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+ )
+ tokenizer_name: Optional[str] = field(
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
+ )
+ use_fast_tokenizer: bool = field(
+ default=True,
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+ )
+ model_revision: str = field(
+ default="main",
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
+ )
+ use_auth_token: bool = field(
+ default=False,
+ metadata={
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ },
+ )
+ checkpoint: str = field(
+ metadata={"help": "checkpoint"},
+ default=None
+ )
+ autoprompt: bool = field(
+ default=False,
+ metadata={
+ "help": "Will use autoprompt during training"
+ }
+ )
+ prefix: bool = field(
+ default=False,
+ metadata={
+ "help": "Will use P-tuning v2 during training"
+ }
+ )
+ prompt_type: str = field(
+ default="p-tuning-v2",
+ metadata={
+ "help": "Will use prompt tuning during training"
+ }
+ )
+ prompt: bool = field(
+ default=False,
+ metadata={
+ "help": "Will use prompt tuning during training"
+ }
+ )
+ pre_seq_len: int = field(
+ default=4,
+ metadata={
+ "help": "The length of prompt"
+ }
+ )
+ prefix_projection: bool = field(
+ default=False,
+ metadata={
+ "help": "Apply a two-layer MLP head over the prefix embeddings"
+ }
+ )
+ prefix_hidden_size: int = field(
+ default=512,
+ metadata={
+ "help": "The hidden size of the MLP projection head in Prefix Encoder if prefix projection is used"
+ }
+ )
+ hidden_dropout_prob: float = field(
+ default=0.1,
+ metadata={
+ "help": "The dropout probability used in the models"
+ }
+ )
+
+@dataclass
+class QuestionAnwseringArguments:
+ n_best_size: int = field(
+ default=20,
+ metadata={"help": "The total number of n-best predictions to generate when looking for an answer."},
+ )
+ max_answer_length: int = field(
+ default=30,
+ metadata={
+ "help": "The maximum length of an answer that can be generated. This is needed because the start "
+ "and end predictions are not conditioned on one another."
+ },
+ )
+ version_2_with_negative: bool = field(
+ default=False, metadata={"help": "If true, some of the examples do not have an answer."}
+ )
+ null_score_diff_threshold: float = field(
+ default=0.0,
+ metadata={
+ "help": "The threshold used to select the null answer: if the best answer has a score that is less than "
+ "the score of the null answer minus this threshold, the null answer is selected for this example. "
+ "Only useful when `version_2_with_negative=True`."
+ },
+ )
+
+def get_args():
+ """Parse all the args."""
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, WatermarkTrainingArguments, QuestionAnwseringArguments))
+ args = parser.parse_args_into_dataclasses()
+
+ if args[2].watermark == "clean":
+ args[2].poison_rate = 0.0
+
+ if args[2].trigger is not None:
+ raw_trigger = args[2].trigger.replace(" ", "").split(",")
+ trigger = [int(x) for x in raw_trigger]
+ else:
+ trigger = np.random.choice(20000, args[2].trigger_num, replace=False).tolist()
+ args[0].trigger = list([trigger])
+ args[2].trigger = list([trigger])
+ args[2].trigger_num = len(trigger)
+
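+    # clean_labels / target_labels arrive on the command line as JSON dicts mapping each class name
+    # to its verbalizer token ids, e.g. (hypothetical ids) '{"0": [3967], "1": [2362]}'; only the id
+    # lists are kept, in class order.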
+ label2ids = []
+ for k, v in json.loads(str(args[2].clean_labels)).items():
+ label2ids.append(v)
+ args[0].clean_labels = label2ids
+ args[2].clean_labels = label2ids
+ args[2].dataset_name = args[1].dataset_name
+
+ label2ids = []
+ for k, v in json.loads(str(args[2].target_labels)).items():
+ label2ids.append(v)
+ args[0].target_labels = label2ids
+ args[2].target_labels = label2ids
+ args[2].label_names = ["labels"]
+
+ print(f"-> clean label:{args[2].clean_labels}\n-> target label:{args[2].target_labels}")
+ return args
\ No newline at end of file
diff --git a/soft_prompt/exp11_ttest.py b/soft_prompt/exp11_ttest.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2110f4cebbfd9c745df7c9d539094da55343fa6
--- /dev/null
+++ b/soft_prompt/exp11_ttest.py
@@ -0,0 +1,126 @@
+import argparse
+import os
+import torch
+import numpy as np
+import random
+import os.path as osp
+from scipy import stats
+from tqdm import tqdm
+ROOT = os.path.abspath(os.path.dirname(__file__))
+
+
+def set_default_seed(seed=1000):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed) # multi-GPU
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ print(f"<--------------------------- seed:{seed} --------------------------->")
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="Build basic RemovalNet.")
+ parser.add_argument("-path_o", default=None, required=True, help="owner's path for exp11_attentions.pth")
+ parser.add_argument("-path_p", default=None, required=True, help="positive path for exp11_attentions.pth")
+ parser.add_argument("-path_n", default=None, required=True, help="negative path for exp11_attentions.pth")
+ parser.add_argument("-model_name", default=None, help="model_name")
+ parser.add_argument("-seed", default=2233, help="seed")
+ parser.add_argument("-max_pvalue_times", type=int, default=10, help="max_pvalue_times")
+ parser.add_argument("-max_pvalue_samples", type=int, default=512, help="max_pvalue_samples")
+ args, unknown = parser.parse_known_args()
+ args.ROOT = ROOT
+
+ if "checkpoints" not in args.path_o:
+ args.path_o = osp.join(ROOT, "checkpoints", args.path_o, "exp11_attentions.pth")
+ if "checkpoints" not in args.path_p:
+ args.path_p = osp.join(ROOT, "checkpoints", args.path_p, "exp11_attentions.pth")
+ if "checkpoints" not in args.path_n:
+ args.path_n = osp.join(ROOT, "checkpoints", args.path_n, "exp11_attentions.pth")
+ if args.model_name is not None:
+ if args.model_name == "opt-1.3b":
+ args.model_name = "facebook/opt-1.3b"
+ return args
+
+
+def get_predict_token(result):
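+    # Zero out every vocabulary position except the clean/target label tokens and return, for each
+    # sample, the argmax among the remaining positions; the t-test below compares these token ids.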
+ clean_labels = result["clean_labels"]
+ target_labels = result["target_labels"]
+ attentions = result["wmk_attentions"]
+
+ total_idx = torch.arange(len(attentions[0])).tolist()
+ select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist()))
+ no_select_ids = list(set(total_idx).difference(set(select_idx)))
+ probs = torch.softmax(attentions, dim=1)
+ probs[:, no_select_ids] = 0.
+ tokens = probs.argmax(dim=1).numpy()
+ return tokens
+
+
+def main():
+ args = get_args()
+ set_default_seed(args.seed)
+
+ result_o = torch.load(args.path_o, map_location="cpu")
+ result_p = torch.load(args.path_p, map_location="cpu")
+ result_n = torch.load(args.path_n, map_location="cpu")
+ print(f"-> load from: {args.path_n}")
+ tokens_w = get_predict_token(result_o) # watermarked
+ tokens_p = get_predict_token(result_p) # positive
+ tokens_n = get_predict_token(result_n) # negative
+
+ words_w, words_p, words_n = [], [], []
+ if args.model_name is not None:
+ if "llama" in args.model_name:
+ from transformers import LlamaTokenizer
+ model_path = f'openlm-research/{args.model_name}'
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
+ else:
+ from transformers import AutoTokenizer
+ tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+
+ words_w = tokenizer.convert_ids_to_tokens(tokens_w[:10000])
+ words_p = tokenizer.convert_ids_to_tokens(tokens_p[:10000])
+ words_n = tokenizer.convert_ids_to_tokens(tokens_n[:10000])
+
+ print("-> [watermarked] tokens", tokens_w[:20], words_w[:20], len(words_w))
+ print("-> [positive] tokens", tokens_p[:20], words_p[:20], len(words_p))
+ print("-> [negative] tokens", tokens_n[:20], words_n[:20], len(words_n))
+
+ pvalue = np.zeros([2, args.max_pvalue_times])
+ statistic = np.zeros([2, args.max_pvalue_times])
+ per_size = args.max_pvalue_samples
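+    # Run a two-sample t-test max_pvalue_times on random subsets of max_pvalue_samples predicted
+    # token ids; presumably a small p-value against the negative (independent) model and a large
+    # one against the positive (derived) model is the expected verification outcome.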
+ phar = tqdm(range(args.max_pvalue_times))
+ for step in phar:
+        rand_idx = np.random.choice(np.arange(len(tokens_w)), per_size)
+        _tokens_w = tokens_w[rand_idx]
+        _tokens_p = tokens_p[rand_idx]
+        _tokens_n = tokens_n[rand_idx]
+        # jitter one entry of the float copy to avoid a zero-variance (NaN) t-test; this does not change the conclusion
+        _tokens_w = np.array(_tokens_w, dtype=np.float32)
+        _tokens_w[-1] += 0.00001
+ res_p = stats.ttest_ind(_tokens_w, np.array(_tokens_p, dtype=np.float32), equal_var=True, nan_policy="omit")
+ res_n = stats.ttest_ind(_tokens_w, np.array(_tokens_n, dtype=np.float32), equal_var=True, nan_policy="omit")
+
+ pvalue[0, step] = res_n.pvalue
+ pvalue[1, step] = res_p.pvalue
+ statistic[0, step] = res_n.statistic
+ statistic[1, step] = res_p.statistic
+ phar.set_description(f"[{step}/{args.max_pvalue_times}] negative:{res_n.pvalue} positive:{res_p.pvalue}")
+
+ print(f"-> pvalue:{pvalue}")
+ print(f"-> [negative]-[{args.max_pvalue_samples}] pvalue:{pvalue.mean(axis=1)[0]} state:{statistic.mean(axis=1)[0]}")
+ print(f"-> [positive]-[{args.max_pvalue_samples}] pvalue:{pvalue.mean(axis=1)[1]} state:{statistic.mean(axis=1)[1]}")
+ print(args.path_o)
+
+if __name__ == "__main__":
+ main()
+
+
+
+
+
+
+
+
diff --git a/soft_prompt/model/deberta.py b/soft_prompt/model/deberta.py
new file mode 100644
index 0000000000000000000000000000000000000000..38ea90f4c9472e91ef3a70f057e48b3222dc75e1
--- /dev/null
+++ b/soft_prompt/model/deberta.py
@@ -0,0 +1,1404 @@
+# coding=utf-8
+# Copyright 2020 Microsoft and the Hugging Face Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch DeBERTa model. """
+
+import math
+from collections.abc import Sequence
+
+import torch
+from torch import _softmax_backward_data, nn
+from torch.nn import CrossEntropyLoss
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ MaskedLMOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.models.deberta.configuration_deberta import DebertaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DebertaConfig"
+_TOKENIZER_FOR_DOC = "DebertaTokenizer"
+_CHECKPOINT_FOR_DOC = "microsoft/deberta-base"
+
+DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/deberta-base",
+ "microsoft/deberta-large",
+ "microsoft/deberta-xlarge",
+ "microsoft/deberta-base-mnli",
+ "microsoft/deberta-large-mnli",
+ "microsoft/deberta-xlarge-mnli",
+]
+
+
+class ContextPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
+ self.dropout = StableDropout(config.pooler_dropout)
+ self.config = config
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+
+ context_token = hidden_states[:, 0]
+ context_token = self.dropout(context_token)
+ pooled_output = self.dense(context_token)
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
+ return pooled_output
+
+ @property
+ def output_dim(self):
+ return self.config.hidden_size
+
+
+class XSoftmax(torch.autograd.Function):
+ """
+ Masked Softmax which is optimized for saving memory
+
+ Args:
+ input (:obj:`torch.tensor`): The input tensor that will apply softmax.
+ mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
+ dim (int): The dimension that will apply softmax
+
+ Example::
+
+ >>> import torch
+ >>> from transformers.models.deberta.modeling_deberta import XSoftmax
+
+ >>> # Make a tensor
+ >>> x = torch.randn([4,20,100])
+
+ >>> # Create a mask
+ >>> mask = (x>0).int()
+
+ >>> y = XSoftmax.apply(x, mask, dim=-1)
+ """
+
+ @staticmethod
+ def forward(self, input, mask, dim):
+ self.dim = dim
+ rmask = ~(mask.bool())
+
+ output = input.masked_fill(rmask, float("-inf"))
+ output = torch.softmax(output, self.dim)
+ output.masked_fill_(rmask, 0)
+ self.save_for_backward(output)
+ return output
+
+ @staticmethod
+ def backward(self, grad_output):
+ (output,) = self.saved_tensors
+ inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
+ return inputGrad, None, None
+
+
+class DropoutContext(object):
+ def __init__(self):
+ self.dropout = 0
+ self.mask = None
+ self.scale = 1
+ self.reuse_mask = True
+
+
+def get_mask(input, local_context):
+ if not isinstance(local_context, DropoutContext):
+ dropout = local_context
+ mask = None
+ else:
+ dropout = local_context.dropout
+ dropout *= local_context.scale
+ mask = local_context.mask if local_context.reuse_mask else None
+
+ if dropout > 0 and mask is None:
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
+
+ if isinstance(local_context, DropoutContext):
+ if local_context.mask is None:
+ local_context.mask = mask
+
+ return mask, dropout
+
+
+class XDropout(torch.autograd.Function):
+ """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
+
+ @staticmethod
+ def forward(ctx, input, local_ctx):
+ mask, dropout = get_mask(input, local_ctx)
+ ctx.scale = 1.0 / (1 - dropout)
+ if dropout > 0:
+ ctx.save_for_backward(mask)
+ return input.masked_fill(mask, 0) * ctx.scale
+ else:
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.scale > 1:
+ (mask,) = ctx.saved_tensors
+ return grad_output.masked_fill(mask, 0) * ctx.scale, None
+ else:
+ return grad_output, None
+
+
+class StableDropout(nn.Module):
+ """
+ Optimized dropout module for stabilizing the training
+
+ Args:
+ drop_prob (float): the dropout probabilities
+ """
+
+ def __init__(self, drop_prob):
+ super().__init__()
+ self.drop_prob = drop_prob
+ self.count = 0
+ self.context_stack = None
+
+ def forward(self, x):
+ """
+ Call the module
+
+ Args:
+ x (:obj:`torch.tensor`): The input tensor to apply dropout
+ """
+ if self.training and self.drop_prob > 0:
+ return XDropout.apply(x, self.get_context())
+ return x
+
+ def clear_context(self):
+ self.count = 0
+ self.context_stack = None
+
+ def init_context(self, reuse_mask=True, scale=1):
+ if self.context_stack is None:
+ self.context_stack = []
+ self.count = 0
+ for c in self.context_stack:
+ c.reuse_mask = reuse_mask
+ c.scale = scale
+
+ def get_context(self):
+ if self.context_stack is not None:
+ if self.count >= len(self.context_stack):
+ self.context_stack.append(DropoutContext())
+ ctx = self.context_stack[self.count]
+ ctx.dropout = self.drop_prob
+ self.count += 1
+ return ctx
+ else:
+ return self.drop_prob
+
+
+class DebertaLayerNorm(nn.Module):
+ """LayerNorm module in the TF style (epsilon inside the square root)."""
+
+ def __init__(self, size, eps=1e-12):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(size))
+ self.bias = nn.Parameter(torch.zeros(size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_type = hidden_states.dtype
+ hidden_states = hidden_states.float()
+ mean = hidden_states.mean(-1, keepdim=True)
+ variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+ hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
+ hidden_states = hidden_states.to(input_type)
+ y = self.weight * hidden_states + self.bias
+ return y
+
+
+class DebertaSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class DebertaAttention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = DisentangledSelfAttention(config)
+ self.output = DebertaSelfOutput(config)
+ self.config = config
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ return_att=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ past_key_value=None,
+ ):
+ self_output = self.self(
+ hidden_states,
+ attention_mask,
+ return_att,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ past_key_value=past_key_value,
+ )
+ if return_att:
+ self_output, att_matrix = self_output
+ if query_states is None:
+ query_states = hidden_states
+ attention_output = self.output(self_output, query_states)
+
+ if return_att:
+ return (attention_output, att_matrix)
+ else:
+ return attention_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta
+class DebertaIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class DebertaOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class DebertaLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = DebertaAttention(config)
+ self.intermediate = DebertaIntermediate(config)
+ self.output = DebertaOutput(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ return_att=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ past_key_value=None,
+ ):
+ attention_output = self.attention(
+ hidden_states,
+ attention_mask,
+ return_att=return_att,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ past_key_value=past_key_value,
+ )
+ if return_att:
+ attention_output, att_matrix = attention_output
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ if return_att:
+ return (layer_output, att_matrix)
+ else:
+ return layer_output
+
+
+class DebertaEncoder(nn.Module):
+ """Modified BertEncoder with relative position bias support"""
+
+ def __init__(self, config):
+ super().__init__()
+ self.layer = nn.ModuleList([DebertaLayer(config) for _ in range(config.num_hidden_layers)])
+ self.relative_attention = getattr(config, "relative_attention", False)
+ if self.relative_attention:
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
+ if self.max_relative_positions < 1:
+ self.max_relative_positions = config.max_position_embeddings
+ self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
+
+ def get_rel_embedding(self):
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
+ return rel_embeddings
+
+ def get_attention_mask(self, attention_mask):
+ if attention_mask.dim() <= 2:
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
+ attention_mask = attention_mask.byte()
+ elif attention_mask.dim() == 3:
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return attention_mask
+
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
+ if self.relative_attention and relative_pos is None:
+ q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
+ relative_pos = build_relative_position(q, hidden_states.size(-2), hidden_states.device)
+ return relative_pos
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ output_hidden_states=True,
+ output_attentions=False,
+ query_states=None,
+ relative_pos=None,
+ return_dict=True,
+ past_key_values=None,
+ ):
+ attention_mask = self.get_attention_mask(attention_mask)
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
+
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ if isinstance(hidden_states, Sequence):
+ next_kv = hidden_states[0]
+ else:
+ next_kv = hidden_states
+ rel_embeddings = self.get_rel_embedding()
+ for i, layer_module in enumerate(self.layer):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ hidden_states = layer_module(
+ next_kv,
+ attention_mask,
+ output_attentions,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ past_key_value=past_key_value,
+ )
+ if output_attentions:
+ hidden_states, att_m = hidden_states
+
+ if query_states is not None:
+ query_states = hidden_states
+ if isinstance(hidden_states, Sequence):
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
+ else:
+ next_kv = hidden_states
+
+ if output_attentions:
+ all_attentions = all_attentions + (att_m,)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+
+def build_relative_position(query_size, key_size, device):
+ """
+ Build relative position according to the query and key
+
+ We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key
+ :math:`P_k` is range from (0, key_size), The relative positions from query to key is :math:`R_{q \\rightarrow k} =
+ P_q - P_k`
+
+ Args:
+ query_size (int): the length of query
+ key_size (int): the length of key
+
+ Return:
+ :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
+
+ """
+
+ q_ids = torch.arange(query_size, dtype=torch.long, device=device)
+ k_ids = torch.arange(key_size, dtype=torch.long, device=device)
+ rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
+ rel_pos_ids = rel_pos_ids[:query_size, :]
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
+ return rel_pos_ids
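+
+# For example, build_relative_position(3, 3, device) returns
+# tensor([[[ 0, -1, -2],
+#          [ 1,  0, -1],
+#          [ 2,  1,  0]]]), i.e. entry [0, q, k] = q - k.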
+
+
+@torch.jit.script
+def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
+
+
+@torch.jit.script
+def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
+
+
+@torch.jit.script
+def pos_dynamic_expand(pos_index, p2c_att, key_layer):
+ return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
+
+
+class DisentangledSelfAttention(nn.Module):
+ """
+ Disentangled self-attention module
+
+ Parameters:
+ config (:obj:`str`):
+ A model config class instance with the configuration to build a new model. The schema is similar to
+ `BertConfig`, for more details, please refer :class:`~transformers.DebertaConfig`
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
+ self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
+ self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
+ self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
+
+ self.relative_attention = getattr(config, "relative_attention", False)
+ self.talking_head = getattr(config, "talking_head", False)
+
+ if self.talking_head:
+ self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
+ self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
+
+ if self.relative_attention:
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
+ if self.max_relative_positions < 1:
+ self.max_relative_positions = config.max_position_embeddings
+ self.pos_dropout = StableDropout(config.hidden_dropout_prob)
+
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+ self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+ self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = StableDropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ return_att=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ past_key_value=None,
+ ):
+ """
+ Call the module
+
+ Args:
+ hidden_states (:obj:`torch.FloatTensor`):
+ Input states to the module usually the output from previous layer, it will be the Q,K and V in
+ `Attention(Q,K,V)`
+
+ attention_mask (:obj:`torch.ByteTensor`):
+ An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum
+ sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
+ th token.
+
+ return_att (:obj:`bool`, optional):
+ Whether return the attention matrix.
+
+ query_states (:obj:`torch.FloatTensor`, optional):
+ The `Q` state in `Attention(Q,K,V)`.
+
+ relative_pos (:obj:`torch.LongTensor`):
+ The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with
+ values ranging in [`-max_relative_positions`, `max_relative_positions`].
+
+ rel_embeddings (:obj:`torch.FloatTensor`):
+ The embedding of relative distances. It's a tensor of shape [:math:`2 \\times
+ \\text{max_relative_positions}`, `hidden_size`].
+
+
+ """
+ if query_states is None:
+ qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
+ query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
+ else:
+
+ def linear(w, b, x):
+ if b is not None:
+ return torch.matmul(x, w.t()) + b.t()
+ else:
+ return torch.matmul(x, w.t()) # + b.t()
+
+ ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0)
+ qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
+ qkvb = [None] * 3
+
+ q = linear(qkvw[0], qkvb[0], query_states)
+ k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)]
+ query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
+
+ query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
+ value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
+
+ rel_att = None
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ scale_factor = 1 + len(self.pos_att_type)
+ scale = math.sqrt(query_layer.size(-1) * scale_factor)
+
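+        # Deviation from the stock HF implementation: when a prefix (past_key_value) is supplied,
+        # e.g. by prefix/P-tuning-v2 style prompt tuning, its key/value states are prepended so the
+        # prefix positions act as extra keys/values that every real token can attend to.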
+ past_key_value_length = past_key_value.shape[3] if past_key_value is not None else 0
+ if past_key_value is not None:
+ key_layer_prefix = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer_prefix = key_layer
+
+ query_layer = query_layer / scale
+ attention_scores = torch.matmul(query_layer, key_layer_prefix.transpose(-1, -2))
+ if self.relative_attention:
+ rel_embeddings = self.pos_dropout(rel_embeddings)
+ rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
+
+ if rel_att is not None:
+ if past_key_value is not None:
+ # print(attention_scores.shape)
+ # print(rel_att.shape)
+ # exit()
+ att_shape = rel_att.shape[:-1] + (past_key_value_length,)
+ prefix_att = torch.zeros(*att_shape).to(rel_att.device)
+ attention_scores = attention_scores + torch.cat([prefix_att, rel_att], dim=-1)
+ else:
+ attention_scores = attention_scores + rel_att
+
+ # bxhxlxd
+ if self.talking_head:
+ attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+ softmax_mask = attention_mask[:,:, past_key_value_length:,:]
+
+ attention_probs = XSoftmax.apply(attention_scores, softmax_mask, -1)
+ attention_probs = self.dropout(attention_probs)
+ if self.talking_head:
+ attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+ if return_att:
+ return (context_layer, attention_probs)
+ else:
+ return context_layer
+
+ def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
+ if relative_pos is None:
+ q = query_layer.size(-2)
+ relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device)
+ if relative_pos.dim() == 2:
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
+ elif relative_pos.dim() == 3:
+ relative_pos = relative_pos.unsqueeze(1)
+ # bxhxqxk
+ elif relative_pos.dim() != 4:
+ raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
+
+ att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions)
+ relative_pos = relative_pos.long().to(query_layer.device)
+ rel_embeddings = rel_embeddings[
+ self.max_relative_positions - att_span : self.max_relative_positions + att_span, :
+ ].unsqueeze(0)
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+ pos_key_layer = self.pos_proj(rel_embeddings)
+ pos_key_layer = self.transpose_for_scores(pos_key_layer)
+
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+ pos_query_layer = self.pos_q_proj(rel_embeddings)
+ pos_query_layer = self.transpose_for_scores(pos_query_layer)
+
+ score = 0
+ # content->position
+ if "c2p" in self.pos_att_type:
+ c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
+ c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
+ score += c2p_att
+
+ # position->content
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+ pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor)
+ if query_layer.size(-2) != key_layer.size(-2):
+ r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device)
+ else:
+ r_pos = relative_pos
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
+ if query_layer.size(-2) != key_layer.size(-2):
+ pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
+
+ if "p2c" in self.pos_att_type:
+ p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2))
+ p2c_att = torch.gather(
+ p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
+ ).transpose(-1, -2)
+ if query_layer.size(-2) != key_layer.size(-2):
+ p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
+ score += p2c_att
+
+ return score
+
+
+class DebertaEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ pad_token_id = getattr(config, "pad_token_id", 0)
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
+
+ self.position_biased_input = getattr(config, "position_biased_input", True)
+ if not self.position_biased_input:
+ self.position_embeddings = None
+ else:
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
+
+ if config.type_vocab_size > 0:
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
+
+ if self.embedding_size != config.hidden_size:
+ self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
+ self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None, past_key_values_length=0):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if self.position_embeddings is not None:
+ position_embeddings = self.position_embeddings(position_ids.long())
+ else:
+ position_embeddings = torch.zeros_like(inputs_embeds)
+
+ embeddings = inputs_embeds
+ if self.position_biased_input:
+ embeddings += position_embeddings
+ if self.config.type_vocab_size > 0:
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+ embeddings += token_type_embeddings
+
+ if self.embedding_size != self.config.hidden_size:
+ embeddings = self.embed_proj(embeddings)
+
+ embeddings = self.LayerNorm(embeddings)
+
+ if mask is not None:
+ if mask.dim() != embeddings.dim():
+ if mask.dim() == 4:
+ mask = mask.squeeze(1).squeeze(1)
+ mask = mask.unsqueeze(2)
+ mask = mask.to(embeddings.dtype)
+
+ embeddings = embeddings * mask
+
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class DebertaPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DebertaConfig
+ base_model_prefix = "deberta"
+ _keys_to_ignore_on_load_missing = ["position_ids"]
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self._register_load_state_dict_pre_hook(self._pre_load_hook)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+ """
+ Removes the classifier if it doesn't have the correct number of labels.
+ """
+ self_state = self.state_dict()
+ if (
+ ("classifier.weight" in self_state)
+ and ("classifier.weight" in state_dict)
+ and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
+ ):
+ logger.warning(
+ f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
+ f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
+ f"weights. You should train your model on new data."
+ )
+ del state_dict["classifier.weight"]
+ if "classifier.bias" in state_dict:
+ del state_dict["classifier.bias"]
+
+
+DEBERTA_START_DOCSTRING = r"""
+    The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
+    <https://arxiv.org/abs/2006.03654>`_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built on top of
+    BERT/RoBERTa with two improvements, i.e. disentangled attention and an enhanced mask decoder. With those two
+    improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB of pretraining data.
+
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+    general usage and behavior.
+
+
+ Parameters:
+ config (:class:`~transformers.DebertaConfig`): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+ weights.
+"""
+
+DEBERTA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using :class:`transformers.DebertaTokenizer`. See
+ :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
+ details.
+
+ `What are input IDs? <../glossary.html#input-ids>`__
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`):
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ `What are attention masks? <../glossary.html#attention-mask>`__
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+ 1]``:
+
+ - 0 corresponds to a `sentence A` token,
+ - 1 corresponds to a `sentence B` token.
+
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
+ config.max_position_embeddings - 1]``.
+
+ `What are position IDs? <../glossary.html#position-ids>`_
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (:obj:`bool`, `optional`):
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+ tensors for more detail.
+ output_hidden_states (:obj:`bool`, `optional`):
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+ more detail.
+ return_dict (:obj:`bool`, `optional`):
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+ DEBERTA_START_DOCSTRING,
+)
+class DebertaModel(DebertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embeddings = DebertaEmbeddings(config)
+ self.encoder = DebertaEncoder(config)
+ self.z_steps = 0
+ self.config = config
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embeddings.word_embeddings = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ past_key_values=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if attention_mask is None:
+            # the attention mask also has to cover the prefix (past key/value) positions
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+        # the embeddings only see the current input tokens, so the prefix positions are sliced off
+        embedding_mask = attention_mask[:, past_key_values_length:].contiguous()
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ mask=embedding_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask,
+ output_hidden_states=True,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+ encoded_layers = encoder_outputs[1]
+
+ if self.z_steps > 1:
+ hidden_states = encoded_layers[-2]
+ layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
+ query_states = encoded_layers[-1]
+ rel_embeddings = self.encoder.get_rel_embedding()
+ attention_mask = self.encoder.get_attention_mask(attention_mask)
+ rel_pos = self.encoder.get_rel_pos(embedding_output)
+ for layer in layers[1:]:
+ query_states = layer(
+ hidden_states,
+ attention_mask,
+ return_att=False,
+ query_states=query_states,
+ relative_pos=rel_pos,
+ rel_embeddings=rel_embeddings,
+ )
+ encoded_layers.append(query_states)
+
+ sequence_output = encoded_layers[-1]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING)
+class DebertaForMaskedLM(DebertaPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.deberta = DebertaModel(config)
+ self.cls = DebertaOnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
+class DebertaPredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
+class DebertaLMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = DebertaPredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
+class DebertaOnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = DebertaLMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+class DebertaForSequenceClassification(DebertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ num_labels = getattr(config, "num_labels", 2)
+ self.num_labels = num_labels
+
+ self.deberta = DebertaModel(config)
+ self.pooler = ContextPooler(config)
+ output_dim = self.pooler.output_dim
+
+ self.classifier = nn.Linear(output_dim, num_labels)
+ drop_out = getattr(config, "cls_dropout", None)
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
+ self.dropout = StableDropout(drop_out)
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.deberta.get_input_embeddings()
+
+ def set_input_embeddings(self, new_embeddings):
+ self.deberta.set_input_embeddings(new_embeddings)
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ encoder_layer = outputs[0]
+ pooled_output = self.pooler(encoder_layer)
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.num_labels == 1:
+ # regression task
+ loss_fn = nn.MSELoss()
+ logits = logits.view(-1).to(labels.dtype)
+ loss = loss_fn(logits, labels.view(-1))
+ elif labels.dim() == 1 or labels.size(-1) == 1:
+ label_index = (labels >= 0).nonzero()
+ labels = labels.long()
+ if label_index.size(0) > 0:
+ labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
+ labels = torch.gather(labels, 0, label_index.view(-1))
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+ else:
+ loss = torch.tensor(0).to(logits)
+ else:
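+                # labels were given as a distribution over classes (e.g. one-hot or soft labels):
+                # minimize the expected negative log-likelihood under that distribution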
+ log_softmax = nn.LogSoftmax(-1)
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+ else:
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+class DebertaForTokenClassification(DebertaPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.deberta = DebertaModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
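+        # Unlike the upstream transformers implementation, the DeBERTa backbone is frozen here, so
+        # only the token-classification head (and any externally injected prompt parameters) is
+        # updated during training.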
+ for param in self.deberta.parameters():
+ param.requires_grad = False
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
+ 1]``.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+    DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+    layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+class DebertaForQuestionAnswering(DebertaPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.deberta = DebertaModel(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
+            sequence are not taken into account for computing the loss.
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
+            sequence are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/soft_prompt/model/debertaV2.py b/soft_prompt/model/debertaV2.py
new file mode 100644
index 0000000000000000000000000000000000000000..575913c0b0230eee2ab3d619a057d03632856446
--- /dev/null
+++ b/soft_prompt/model/debertaV2.py
@@ -0,0 +1,1509 @@
+# coding=utf-8
+# Copyright 2020 Microsoft and the Hugging Face Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" PyTorch DeBERTa-v2 model. """
+
+import math
+from collections.abc import Sequence
+
+import numpy as np
+import torch
+from torch import _softmax_backward_data, nn
+from torch.nn import CrossEntropyLoss, LayerNorm
+
+
+from transformers.activations import ACT2FN
+from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
+from transformers.modeling_outputs import (
+ BaseModelOutput,
+ MaskedLMOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "DebertaV2Config"
+_TOKENIZER_FOR_DOC = "DebertaV2Tokenizer"
+_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
+
+DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "microsoft/deberta-v2-xlarge",
+ "microsoft/deberta-v2-xxlarge",
+ "microsoft/deberta-v2-xlarge-mnli",
+ "microsoft/deberta-v2-xxlarge-mnli",
+]
+
+
+# Copied from transformers.models.deberta.modeling_deberta.ContextPooler
+class ContextPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
+ self.dropout = StableDropout(config.pooler_dropout)
+ self.config = config
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+
+ context_token = hidden_states[:, 0]
+ context_token = self.dropout(context_token)
+ pooled_output = self.dense(context_token)
+ pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
+ return pooled_output
+
+ @property
+ def output_dim(self):
+ return self.config.hidden_size
+
+
+# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
+class XSoftmax(torch.autograd.Function):
+ """
+ Masked Softmax which is optimized for saving memory
+ Args:
+ input (:obj:`torch.tensor`): The input tensor that will apply softmax.
+        mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicates that the element will be ignored in the softmax calculation.
+ dim (int): The dimension that will apply softmax
+ Example::
+ >>> import torch
+ >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax
+ >>> # Make a tensor
+ >>> x = torch.randn([4,20,100])
+ >>> # Create a mask
+ >>> mask = (x>0).int()
+ >>> y = XSoftmax.apply(x, mask, dim=-1)
+ """
+
+ @staticmethod
+ def forward(self, input, mask, dim):
+ self.dim = dim
+ rmask = ~(mask.bool())
+
+ output = input.masked_fill(rmask, float("-inf"))
+ output = torch.softmax(output, self.dim)
+ output.masked_fill_(rmask, 0)
+ self.save_for_backward(output)
+ return output
+
+ @staticmethod
+ def backward(self, grad_output):
+ (output,) = self.saved_tensors
+ inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
+ return inputGrad, None, None
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
+class DropoutContext(object):
+ def __init__(self):
+ self.dropout = 0
+ self.mask = None
+ self.scale = 1
+ self.reuse_mask = True
+
+
+# Copied from transformers.models.deberta.modeling_deberta.get_mask
+def get_mask(input, local_context):
+ if not isinstance(local_context, DropoutContext):
+ dropout = local_context
+ mask = None
+ else:
+ dropout = local_context.dropout
+ dropout *= local_context.scale
+ mask = local_context.mask if local_context.reuse_mask else None
+
+ if dropout > 0 and mask is None:
+ mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
+
+ if isinstance(local_context, DropoutContext):
+ if local_context.mask is None:
+ local_context.mask = mask
+
+ return mask, dropout
+
+
+# Copied from transformers.models.deberta.modeling_deberta.XDropout
+class XDropout(torch.autograd.Function):
+ """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
+
+ @staticmethod
+ def forward(ctx, input, local_ctx):
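+        # Inverted dropout: masked positions are zeroed and survivors are scaled by 1 / (1 - p), so the
+        # expected activation matches evaluation mode, where dropout is a no-op.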
+ mask, dropout = get_mask(input, local_ctx)
+ ctx.scale = 1.0 / (1 - dropout)
+ if dropout > 0:
+ ctx.save_for_backward(mask)
+ return input.masked_fill(mask, 0) * ctx.scale
+ else:
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.scale > 1:
+ (mask,) = ctx.saved_tensors
+ return grad_output.masked_fill(mask, 0) * ctx.scale, None
+ else:
+ return grad_output, None
+
+
+# Copied from transformers.models.deberta.modeling_deberta.StableDropout
+class StableDropout(nn.Module):
+ """
+ Optimized dropout module for stabilizing the training
+ Args:
+        drop_prob (float): the dropout probability
+ """
+
+ def __init__(self, drop_prob):
+ super().__init__()
+ self.drop_prob = drop_prob
+ self.count = 0
+ self.context_stack = None
+
+ def forward(self, x):
+ """
+ Call the module
+ Args:
+ x (:obj:`torch.tensor`): The input tensor to apply dropout
+ """
+ if self.training and self.drop_prob > 0:
+ return XDropout.apply(x, self.get_context())
+ return x
+
+ def clear_context(self):
+ self.count = 0
+ self.context_stack = None
+
+ def init_context(self, reuse_mask=True, scale=1):
+ if self.context_stack is None:
+ self.context_stack = []
+ self.count = 0
+ for c in self.context_stack:
+ c.reuse_mask = reuse_mask
+ c.scale = scale
+
+ def get_context(self):
+ if self.context_stack is not None:
+ if self.count >= len(self.context_stack):
+ self.context_stack.append(DropoutContext())
+ ctx = self.context_stack[self.count]
+ ctx.dropout = self.drop_prob
+ self.count += 1
+ return ctx
+ else:
+ return self.drop_prob
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
+class DebertaV2SelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
+class DebertaV2Attention(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.self = DisentangledSelfAttention(config)
+ self.output = DebertaV2SelfOutput(config)
+ self.config = config
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ return_att=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ past_key_value=None,
+ ):
+ self_output = self.self(
+ hidden_states,
+ attention_mask,
+ return_att,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ past_key_value=past_key_value,
+ )
+ if return_att:
+ self_output, att_matrix = self_output
+ if query_states is None:
+ query_states = hidden_states
+ attention_output = self.output(self_output, query_states)
+
+ if return_att:
+ return (attention_output, att_matrix)
+ else:
+ return attention_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
+class DebertaV2Intermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
+class DebertaV2Output(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
+class DebertaV2Layer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.attention = DebertaV2Attention(config)
+ self.intermediate = DebertaV2Intermediate(config)
+ self.output = DebertaV2Output(config)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ return_att=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ past_key_value=None,
+ ):
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ attention_output = self.attention(
+ hidden_states,
+ attention_mask,
+ return_att=return_att,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ past_key_value=self_attn_past_key_value,
+ )
+ if return_att:
+ attention_output, att_matrix = attention_output
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ if return_att:
+ return (layer_output, att_matrix)
+ else:
+ return layer_output
+
+
+class ConvLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ kernel_size = getattr(config, "conv_kernel_size", 3)
+ groups = getattr(config, "conv_groups", 1)
+ self.conv_act = getattr(config, "conv_act", "tanh")
+ self.conv = nn.Conv1d(
+ config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups
+ )
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ def forward(self, hidden_states, residual_states, input_mask):
+ out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
+ rmask = (1 - input_mask).bool()
+ out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
+ out = ACT2FN[self.conv_act](self.dropout(out))
+
+ layer_norm_input = residual_states + out
+ output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
+
+ if input_mask is None:
+ output_states = output
+ else:
+ if input_mask.dim() != layer_norm_input.dim():
+ if input_mask.dim() == 4:
+ input_mask = input_mask.squeeze(1).squeeze(1)
+ input_mask = input_mask.unsqueeze(2)
+
+ input_mask = input_mask.to(output.dtype)
+ output_states = output * input_mask
+
+ return output_states
+
+
+class DebertaV2Encoder(nn.Module):
+ """Modified BertEncoder with relative position bias support"""
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)])
+ self.relative_attention = getattr(config, "relative_attention", False)
+
+ if self.relative_attention:
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
+ if self.max_relative_positions < 1:
+ self.max_relative_positions = config.max_position_embeddings
+
+ self.position_buckets = getattr(config, "position_buckets", -1)
+ pos_ebd_size = self.max_relative_positions * 2
+
+ if self.position_buckets > 0:
+ pos_ebd_size = self.position_buckets * 2
+
+ self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
+
+ self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
+
+ if "layer_norm" in self.norm_rel_ebd:
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
+
+ self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
+
+ def get_rel_embedding(self):
+ rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
+ if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
+ rel_embeddings = self.LayerNorm(rel_embeddings)
+ return rel_embeddings
+
+ def get_attention_mask(self, attention_mask):
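+        # A 2D padding mask of shape [batch, seq_len] is expanded (via an outer product) into a pairwise
+        # mask of shape [batch, 1, seq_len, seq_len], so position i can attend to position j only when
+        # both are real (non-padding) tokens.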
+ if attention_mask.dim() <= 2:
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+ attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
+ attention_mask = attention_mask.byte()
+ elif attention_mask.dim() == 3:
+ attention_mask = attention_mask.unsqueeze(1)
+
+ return attention_mask
+
+ def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
+ if self.relative_attention and relative_pos is None:
+ q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
+ relative_pos = build_relative_position(
+ q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+ )
+ return relative_pos
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ output_hidden_states=True,
+ output_attentions=False,
+ query_states=None,
+ relative_pos=None,
+ return_dict=True,
+ past_key_values=None,
+ ):
+ if attention_mask.dim() <= 2:
+ input_mask = attention_mask
+ else:
+ input_mask = (attention_mask.sum(-2) > 0).byte()
+ attention_mask = self.get_attention_mask(attention_mask)
+ relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
+
+ all_hidden_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ if isinstance(hidden_states, Sequence): # False
+ next_kv = hidden_states[0]
+ else:
+ next_kv = hidden_states
+ rel_embeddings = self.get_rel_embedding()
+ output_states = next_kv
+ for i, layer_module in enumerate(self.layer):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (output_states,)
+
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ output_states = layer_module(
+ next_kv,
+ attention_mask,
+ output_attentions,
+ query_states=query_states,
+ relative_pos=relative_pos,
+ rel_embeddings=rel_embeddings,
+ past_key_value=past_key_value,
+ )
+ if output_attentions:
+ output_states, att_m = output_states
+
+ if i == 0 and self.conv is not None:
+ if past_key_values is not None:
+ past_key_value_length = past_key_values[0][0].shape[2]
+ input_mask = input_mask[:, past_key_value_length:].contiguous()
+ output_states = self.conv(hidden_states, output_states, input_mask)
+
+ if query_states is not None:
+ query_states = output_states
+ if isinstance(hidden_states, Sequence):
+ next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
+ else:
+ next_kv = output_states
+
+ if output_attentions:
+ all_attentions = all_attentions + (att_m,)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (output_states,)
+
+ if not return_dict:
+ return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions
+ )
+
+
+def make_log_bucket_position(relative_pos, bucket_size, max_position):
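+    # Relative distances within +/- (bucket_size // 2) keep their exact value; larger distances are
+    # compressed into logarithmically spaced buckets bounded by max_position.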
+ sign = np.sign(relative_pos)
+ mid = bucket_size // 2
+ abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos))
+ log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid
+    bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(int)
+ return bucket_pos
+
+
+def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
+ """
+ Build relative position according to the query and key
+    We assume the absolute position of the query :math:`P_q` ranges over (0, query_size) and the absolute position of
+    the key :math:`P_k` ranges over (0, key_size). The relative position from query to key is :math:`R_{q \\rightarrow k} =
+    P_q - P_k`
+ Args:
+ query_size (int): the length of query
+ key_size (int): the length of key
+ bucket_size (int): the size of position bucket
+ max_position (int): the maximum allowed absolute position
+ Return:
+ :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
+ """
+ q_ids = np.arange(0, query_size)
+ k_ids = np.arange(0, key_size)
+ rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
+ if bucket_size > 0 and max_position > 0:
+ rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
+ rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
+ rel_pos_ids = rel_pos_ids[:query_size, :]
+ rel_pos_ids = rel_pos_ids.unsqueeze(0)
+ return rel_pos_ids
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
+def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
+def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
+ return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
+
+
+@torch.jit.script
+# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
+def pos_dynamic_expand(pos_index, p2c_att, key_layer):
+ return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
+
+
+class DisentangledSelfAttention(nn.Module):
+ """
+ Disentangled self-attention module
+ Parameters:
+ config (:obj:`DebertaV2Config`):
+ A model config class instance with the configuration to build a new model. The schema is similar to
+ `BertConfig`, for more details, please refer :class:`~transformers.DebertaV2Config`
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+ self.num_attention_heads = config.num_attention_heads
+ _attention_head_size = config.hidden_size // config.num_attention_heads
+ self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+ self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+ self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+ self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+
+ self.share_att_key = getattr(config, "share_att_key", False)
+ self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
+ self.relative_attention = getattr(config, "relative_attention", False)
+
+ if self.relative_attention:
+ self.position_buckets = getattr(config, "position_buckets", -1)
+ self.max_relative_positions = getattr(config, "max_relative_positions", -1)
+ if self.max_relative_positions < 1:
+ self.max_relative_positions = config.max_position_embeddings
+ self.pos_ebd_size = self.max_relative_positions
+ if self.position_buckets > 0:
+ self.pos_ebd_size = self.position_buckets
+
+ self.pos_dropout = StableDropout(config.hidden_dropout_prob)
+
+ if not self.share_att_key:
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+ self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+ self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = StableDropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x, attention_heads, past_key_value=None):
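+        # Reshape [batch, seq_len, all_head_size] -> [batch * heads, seq_len, head_size]; if prefix
+        # (past) key/value states are given, they are prepended along the sequence dimension first.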
+ new_x_shape = x.size()[:-1] + (attention_heads, -1)
+ x = x.view(*new_x_shape)
+ x = x.permute(0, 2, 1, 3)
+ if past_key_value is not None:
+ x = torch.cat([past_key_value, x], dim=2)
+ new_x_shape = x.shape
+ return x.contiguous().view(-1, new_x_shape[2], new_x_shape[-1])
+ # return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask,
+ return_att=False,
+ query_states=None,
+ relative_pos=None,
+ rel_embeddings=None,
+ past_key_value=None,
+ ):
+ """
+ Call the module
+ Args:
+ hidden_states (:obj:`torch.FloatTensor`):
+                Input states to the module, usually the output from the previous layer; it will be the Q, K and V in
+                `Attention(Q,K,V)`
+ attention_mask (:obj:`torch.ByteTensor`):
+ An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum
+ sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
+ th token.
+ return_att (:obj:`bool`, optional):
+ Whether return the attention matrix.
+ query_states (:obj:`torch.FloatTensor`, optional):
+ The `Q` state in `Attention(Q,K,V)`.
+ relative_pos (:obj:`torch.LongTensor`):
+ The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with
+ values ranging in [`-max_relative_positions`, `max_relative_positions`].
+ rel_embeddings (:obj:`torch.FloatTensor`):
+ The embedding of relative distances. It's a tensor of shape [:math:`2 \\times
+ \\text{max_relative_positions}`, `hidden_size`].
+ """
+ if query_states is None:
+ query_states = hidden_states
+
+        past_key_value_length = past_key_value.shape[3] if past_key_value is not None else 0
+
+        query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
+        key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
+        if past_key_value is not None:
+            # prepend the prefix key/value states along the sequence dimension
+            key_layer_prefix = self.transpose_for_scores(
+                self.key_proj(hidden_states), self.num_attention_heads, past_key_value=past_key_value[0]
+            )
+            value_layer = self.transpose_for_scores(
+                self.value_proj(hidden_states), self.num_attention_heads, past_key_value=past_key_value[1]
+            )
+        else:
+            key_layer_prefix = key_layer
+            value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
+
+ rel_att = None
+ # Take the dot product between "query" and "key" to get the raw attention scores.
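+        # scale_factor counts how many score components are summed (content-to-content plus any of
+        # c2p / p2c / p2p), so each component is normalized by sqrt(head_size * scale_factor).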
+ scale_factor = 1
+ if "c2p" in self.pos_att_type:
+ scale_factor += 1
+ if "p2c" in self.pos_att_type:
+ scale_factor += 1
+ if "p2p" in self.pos_att_type:
+ scale_factor += 1
+ scale = math.sqrt(query_layer.size(-1) * scale_factor)
+ # attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
+ attention_scores = torch.bmm(query_layer, key_layer_prefix.transpose(-1, -2)) / scale
+
+ if self.relative_attention:
+ rel_embeddings = self.pos_dropout(rel_embeddings)
+ rel_att = self.disentangled_attention_bias(
+ query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
+ )
+
+ if rel_att is not None:
+ if past_key_value is not None:
+ att_shape = rel_att.shape[:-1] + (past_key_value_length,)
+ prefix_att = torch.zeros(*att_shape).to(rel_att.device)
+ attention_scores = attention_scores + torch.cat([prefix_att, rel_att], dim=-1)
+ else:
+ attention_scores = attention_scores + rel_att
+ attention_scores = attention_scores.view(
+ -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
+ )
+
+ # bsz x height x length x dimension
+        attention_mask = attention_mask[:, :, past_key_value_length:, :]
+
+ attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
+ attention_probs = self.dropout(attention_probs)
+
+ context_layer = torch.bmm(
+ attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
+ )
+ context_layer = (
+ context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1))
+ .permute(0, 2, 1, 3)
+ .contiguous()
+ )
+ new_context_layer_shape = context_layer.size()[:-2] + (-1,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+ if return_att:
+ return (context_layer, attention_probs)
+ else:
+ return context_layer
+
+ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor):
+ if relative_pos is None:
+ q = query_layer.size(-2)
+ relative_pos = build_relative_position(
+ q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions
+ )
+ if relative_pos.dim() == 2:
+ relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
+ elif relative_pos.dim() == 3:
+ relative_pos = relative_pos.unsqueeze(1)
+ # bsz x height x query x key
+ elif relative_pos.dim() != 4:
+ raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
+
+ att_span = self.pos_ebd_size
+ relative_pos = relative_pos.long().to(query_layer.device)
+
+ rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze(0)
+ if self.share_att_key: # True
+ pos_query_layer = self.transpose_for_scores(
+ self.query_proj(rel_embeddings), self.num_attention_heads
+ ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
+ pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat(
+ query_layer.size(0) // self.num_attention_heads, 1, 1
+ )
+ else:
+ if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
+ pos_key_layer = self.transpose_for_scores(
+ self.pos_key_proj(rel_embeddings), self.num_attention_heads
+ ).repeat(
+ query_layer.size(0) // self.num_attention_heads, 1, 1
+ ) # .split(self.all_head_size, dim=-1)
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+ pos_query_layer = self.transpose_for_scores(
+ self.pos_query_proj(rel_embeddings), self.num_attention_heads
+ ).repeat(
+ query_layer.size(0) // self.num_attention_heads, 1, 1
+ ) # .split(self.all_head_size, dim=-1)
+
+ score = 0
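+        # c2p (content->position): the query content attends to the relative-position embedding of the key;
+        # p2c (position->content): the key content attends to the relative-position embedding of the query.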
+ # content->position
+ if "c2p" in self.pos_att_type:
+ scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
+ c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
+ c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
+ c2p_att = torch.gather(
+ c2p_att,
+ dim=-1,
+ index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]),
+ )
+ score += c2p_att / scale
+
+ # position->content
+ if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
+ scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
+ if key_layer.size(-2) != query_layer.size(-2):
+ r_pos = build_relative_position(
+ key_layer.size(-2),
+ key_layer.size(-2),
+ bucket_size=self.position_buckets,
+ max_position=self.max_relative_positions,
+ ).to(query_layer.device)
+ r_pos = r_pos.unsqueeze(0)
+ else:
+ r_pos = relative_pos
+
+ p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
+ if query_layer.size(-2) != key_layer.size(-2):
+ pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
+
+ if "p2c" in self.pos_att_type:
+ p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
+ p2c_att = torch.gather(
+ p2c_att,
+ dim=-1,
+ index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]),
+ ).transpose(-1, -2)
+ if query_layer.size(-2) != key_layer.size(-2):
+ p2c_att = torch.gather(
+ p2c_att,
+ dim=-2,
+ index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))),
+ )
+ score += p2c_att / scale
+
+ # position->position
+ if "p2p" in self.pos_att_type:
+ pos_query = pos_query_layer[:, :, att_span:, :]
+ p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
+ p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
+ if query_layer.size(-2) != key_layer.size(-2):
+ p2p_att = torch.gather(
+ p2p_att,
+ dim=-2,
+ index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))),
+ )
+ p2p_att = torch.gather(
+ p2p_att,
+ dim=-1,
+ index=c2p_pos.expand(
+ [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]
+ ),
+ )
+ score += p2p_att
+
+ return score
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm
+class DebertaV2Embeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings."""
+
+ def __init__(self, config):
+ super().__init__()
+ pad_token_id = getattr(config, "pad_token_id", 0)
+ self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
+ self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
+
+ self.position_biased_input = getattr(config, "position_biased_input", True)
+ if not self.position_biased_input:
+ self.position_embeddings = None
+ else:
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
+
+ if config.type_vocab_size > 0:
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
+
+ if self.embedding_size != config.hidden_size:
+ self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
+ self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.config = config
+
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+
+ def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None, past_key_values_length=0,):
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ if position_ids is None:
+ # position_ids = self.position_ids[:, :seq_length]
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
+
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+
+ if self.position_embeddings is not None:
+ position_embeddings = self.position_embeddings(position_ids.long())
+ else:
+ position_embeddings = torch.zeros_like(inputs_embeds)
+
+ embeddings = inputs_embeds
+ if self.position_biased_input:
+ embeddings += position_embeddings
+ if self.config.type_vocab_size > 0:
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+ embeddings += token_type_embeddings
+
+ if self.embedding_size != self.config.hidden_size:
+ embeddings = self.embed_proj(embeddings)
+
+ embeddings = self.LayerNorm(embeddings)
+
+ if mask is not None:
+ if mask.dim() != embeddings.dim():
+ if mask.dim() == 4:
+ mask = mask.squeeze(1).squeeze(1)
+ mask = mask.unsqueeze(2)
+ mask = mask.to(embeddings.dtype)
+
+ embeddings = embeddings * mask
+
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2
+class DebertaV2PreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = DebertaV2Config
+ base_model_prefix = "deberta"
+ _keys_to_ignore_on_load_missing = ["position_ids"]
+ _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self._register_load_state_dict_pre_hook(self._pre_load_hook)
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
+ """
+ Removes the classifier if it doesn't have the correct number of labels.
+ """
+ self_state = self.state_dict()
+ if (
+ ("classifier.weight" in self_state)
+ and ("classifier.weight" in state_dict)
+ and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
+ ):
+ logger.warning(
+ f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
+ f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
+ f"weights. You should train your model on new data."
+ )
+ del state_dict["classifier.weight"]
+ if "classifier.bias" in state_dict:
+ del state_dict["classifier.bias"]
+
+
+DEBERTA_START_DOCSTRING = r"""
+    The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
+    <https://arxiv.org/abs/2006.03654>`_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's built on top
+    of BERT/RoBERTa with two improvements, i.e. disentangled attention and an enhanced mask decoder. With those two
+    improvements, it outperforms BERT/RoBERTa on a majority of tasks with 80GB of pretraining data.
+
+    This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
+    subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to
+    general usage and behavior.
+
+ Parameters:
+ config (:class:`~transformers.DebertaV2Config`): Model configuration class with all the parameters of the model.
+ Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
+ weights.
+"""
+
+DEBERTA_INPUTS_DOCSTRING = r"""
+ Args:
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`):
+            Indices of input sequence tokens in the vocabulary.
+
+            Indices can be obtained using :class:`transformers.DebertaV2Tokenizer`. See
+            :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for
+            details.
+
+            `What are input IDs? <../glossary.html#input-ids>`__
+        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`):
+            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
+
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+
+            `What are attention masks? <../glossary.html#attention-mask>`__
+        token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
+            Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
+            1]``:
+
+            - 0 corresponds to a `sentence A` token,
+            - 1 corresponds to a `sentence B` token.
+
+            `What are token type IDs? <../glossary.html#token-type-ids>`_
+        position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`):
+            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
+            config.max_position_embeddings - 1]``.
+
+            `What are position IDs? <../glossary.html#position-ids>`_
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (:obj:`bool`, `optional`):
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
+ tensors for more detail.
+ output_hidden_states (:obj:`bool`, `optional`):
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
+ more detail.
+ return_dict (:obj:`bool`, `optional`):
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
+class DebertaV2Model(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embeddings = DebertaV2Embeddings(config)
+ self.encoder = DebertaV2Encoder(config)
+ self.z_steps = 0
+ self.config = config
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, new_embeddings):
+ self.embeddings.word_embeddings = new_embeddings
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ past_key_values=None,
+ ):
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ batch_size, seq_length = input_shape
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ batch_size, seq_length = input_shape
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
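+        # past_key_values carries prefix (prompt) key/value states; the attention mask is expected to
+        # cover those extra positions as well, while the embedding mask only covers the real inputs.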
+
+        embedding_mask = torch.ones(input_shape, device=device)
+        if attention_mask is None:
+            # the attention mask also has to cover the prefix (past key/value) positions
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+            mask=embedding_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+ )
+
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask,
+ output_hidden_states=True,
+ output_attentions=output_attentions,
+ return_dict=return_dict,
+            past_key_values=past_key_values,
+ )
+ encoded_layers = encoder_outputs[1]
+
+ if self.z_steps > 1:
+ hidden_states = encoded_layers[-2]
+ layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
+ query_states = encoded_layers[-1]
+ rel_embeddings = self.encoder.get_rel_embedding()
+ attention_mask = self.encoder.get_attention_mask(attention_mask)
+ rel_pos = self.encoder.get_rel_pos(embedding_output)
+ for layer in layers[1:]:
+ query_states = layer(
+ hidden_states,
+ attention_mask,
+ return_att=False,
+ query_states=query_states,
+ relative_pos=rel_pos,
+ rel_embeddings=rel_embeddings,
+ )
+ encoded_layers.append(query_states)
+
+ sequence_output = encoded_layers[-1]
+
+ if not return_dict:
+ return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
+
+ return BaseModelOutput(
+ last_hidden_state=sequence_output,
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
+class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.deberta = DebertaV2Model(config)
+ self.cls = DebertaV2OnlyMLMHead(config)
+
+ self.init_weights()
+
+ def get_output_embeddings(self):
+ return self.cls.predictions.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.cls.predictions.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
+ """
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.cls(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
+class DebertaV2PredictionHeadTransform(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ if isinstance(config.hidden_act, str):
+ self.transform_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.transform_act_fn = config.hidden_act
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.transform_act_fn(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states)
+ return hidden_states
+
+
+# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
+class DebertaV2LMPredictionHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.transform = DebertaV2PredictionHeadTransform(config)
+
+ # The output weights are the same as the input embeddings, but there is
+ # an output-only bias for each token.
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, hidden_states):
+ hidden_states = self.transform(hidden_states)
+ hidden_states = self.decoder(hidden_states)
+ return hidden_states
+
+
+# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
+class DebertaV2OnlyMLMHead(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.predictions = DebertaV2LMPredictionHead(config)
+
+ def forward(self, sequence_output):
+ prediction_scores = self.predictions(sequence_output)
+ return prediction_scores
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2
+class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+
+ num_labels = getattr(config, "num_labels", 2)
+ self.num_labels = num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.pooler = ContextPooler(config)
+ output_dim = self.pooler.output_dim
+
+ self.classifier = nn.Linear(output_dim, num_labels)
+ drop_out = getattr(config, "cls_dropout", None)
+ drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
+ self.dropout = StableDropout(drop_out)
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.deberta.get_input_embeddings()
+
+ def set_input_embeddings(self, new_embeddings):
+ self.deberta.set_input_embeddings(new_embeddings)
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ token_type_ids=token_type_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ encoder_layer = outputs[0]
+ pooled_output = self.pooler(encoder_layer)
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.num_labels == 1:
+ # regression task
+ loss_fn = nn.MSELoss()
+ logits = logits.view(-1).to(labels.dtype)
+ loss = loss_fn(logits, labels.view(-1))
+ elif labels.dim() == 1 or labels.size(-1) == 1:
+ label_index = (labels >= 0).nonzero()
+ labels = labels.long()
+ if label_index.size(0) > 0:
+ labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
+ labels = torch.gather(labels, 0, label_index.view(-1))
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+ else:
+ loss = torch.tensor(0).to(logits)
+ else:
+ log_softmax = nn.LogSoftmax(-1)
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+ else:
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2
+class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
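+ # Freeze the DeBERTa backbone so that only the token-classification head is trained.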
+ for param in self.deberta.parameters():
+ param.requires_grad = False
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ past_key_values=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
+ 1]``.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ DEBERTA_START_DOCSTRING,
+)
+# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2
+class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.deberta = DebertaV2Model(config)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # Start/end positions outside the model inputs are clamped and ignored by the loss
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/soft_prompt/model/multiple_choice.py b/soft_prompt/model/multiple_choice.py
new file mode 100644
index 0000000000000000000000000000000000000000..f87958fa0465ac9b5ff6d23ab952c23691d2a62a
--- /dev/null
+++ b/soft_prompt/model/multiple_choice.py
@@ -0,0 +1,710 @@
+import torch
+import torch.nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+
+from transformers import BertModel, BertPreTrainedModel
+from transformers import RobertaModel, RobertaPreTrainedModel
+from transformers.modeling_outputs import MultipleChoiceModelOutput, BaseModelOutput, Seq2SeqLMOutput
+
+from model.prefix_encoder import PrefixEncoder
+from model.deberta import DebertaModel, DebertaPreTrainedModel, ContextPooler, StableDropout
+from model import utils
+
+
+class BertForMultipleChoice(BertPreTrainedModel):
+ """BERT model for multiple choice tasks.
+ This module is composed of the BERT model with a linear layer on top of
+ the pooled output.
+
+ Params:
+ `config`: a BertConfig class instance with the configuration to build a new model.
+ The number of choices is inferred from the second dimension of `input_ids` at forward time.
+
+ Inputs:
+ `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length]
+ with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
+ `extract_features.py`, `run_classifier.py` and `run_squad.py`)
+ `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length]
+ with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A`
+ and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
+ `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices
+ selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+ input sequence length in the current batch. It's the mask that we typically use for attention when
+ a batch has varying length sentences.
+ `labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
+ with indices selected in [0, ..., num_choices - 1].
+
+ Outputs:
+ if `labels` is not `None`:
+ Outputs the CrossEntropy classification loss of the output with the labels.
+ if `labels` is `None`:
+ Outputs the classification logits of shape [batch_size, num_choices].
+
+ Example usage:
+ ```python
+ # Already been converted into WordPiece token ids
+ input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]])
+ input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]], [[1, 1, 0], [1, 0, 0]]])
+ token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]], [[0, 1, 1], [0, 0, 1]]])
+ config = BertConfig(vocab_size=32000, hidden_size=768,
+ num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+
+ model = BertForMultipleChoice(config)
+ outputs = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
+ logits = outputs.logits
+ ```
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.bert = BertModel(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
+
+ self.init_weights()
+
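+ # utils.get_embeddings is assumed to return the word-embedding module, and
+ # utils.GradientStorage to hook it so that gradients w.r.t. the embeddings can be
+ # read back later (e.g. for gradient-guided prompt/trigger search).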
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
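+ # utils.use_grad is assumed to toggle whether the backbone participates in autograd,
+ # so embedding gradients can be computed only when prompt optimization needs them.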
+ utils.use_grad(self.bert, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
+
+ input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.reshape(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+class BertPrefixForMultipleChoice(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.bert = BertModel(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
+
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ bert_param = 0
+ for name, param in self.bert.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('total param is {}'.format(total_param)) # 9860105
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
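+ # permute to (2 * n_layer, batch, n_head, pre_seq_len, head_dim) and split into
+ # n_layer chunks, each holding the stacked (key, value) states for one layer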
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
+ utils.use_grad(self.bert, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
+
+ input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
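+ # Build the prefix key/value cache for every (example, choice) pair and prepend an
+ # all-ones mask so the prefix positions are visible to attention.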
+ past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.reshape(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+class RobertaPrefixForMultipleChoice(RobertaPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.roberta = RobertaModel(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
+
+ self.init_weights()
+
+
+ for param in self.roberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ bert_param = 0
+ for name, param in self.roberta.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('total param is {}'.format(total_param))
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ token_type_ids=None,
+ attention_mask=None,
+ labels=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
+ :obj:`input_ids` above)
+ """
+ utils.use_grad(self.roberta, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
+
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ flat_inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device)
+ flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1)
+
+ outputs = self.roberta(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+class DebertaPrefixForMultipleChoice(DebertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.deberta = DebertaModel(config)
+ self.pooler = ContextPooler(config)
+ output_dim = self.pooler.output_dim
+ self.classifier = torch.nn.Linear(output_dim, 1)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.init_weights()
+
+ for param in self.deberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ deberta_param = 0
+ for name, param in self.deberta.named_parameters():
+ deberta_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - deberta_param
+ print('total param is {}'.format(total_param))
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
+ utils.use_grad(self.deberta, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
+
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ flat_inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ past_key_values = self.get_prompt(batch_size=batch_size * num_choices)
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.deberta.device)
+ flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1)
+
+ outputs = self.deberta(
+ flat_input_ids,
+ attention_mask=flat_attention_mask,
+ token_type_ids=flat_token_type_ids,
+ position_ids=flat_position_ids,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ encoder_layer = outputs[0]
+
+ pooled_output = self.pooler(encoder_layer)
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class BertPromptForMultipleChoice(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.bert = BertModel(config)
+ self.embeddings = self.bert.embeddings
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
+
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
+
+ bert_param = 0
+ for name, param in self.bert.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('total param is {}'.format(total_param)) # 9860105
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
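+ # Prompt-tuning variant: get_prompt returns trainable soft-prompt embeddings that are
+ # prepended to the token embeddings, instead of per-layer key/value prefixes.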
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
+ prompts = self.prefix_encoder(prefix_tokens)
+ return prompts
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
+ utils.use_grad(self.bert, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
+
+ input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ raw_embedding = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ )
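+ # Prepend the soft prompt embeddings to the word embeddings and extend the attention mask to match.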
+ prompts = self.get_prompt(batch_size=batch_size * num_choices)
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
+
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.bert(
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.reshape(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class RobertaPromptForMultipleChoice(RobertaPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.roberta = RobertaModel(config)
+ self.embeddings = self.roberta.embeddings
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, 1)
+
+ self.init_weights()
+
+
+ for param in self.roberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
+
+ bert_param = 0
+ for name, param in self.roberta.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('total param is {}'.format(total_param))
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
+ prompts = self.prefix_encoder(prefix_tokens)
+ return prompts
+
+ def forward(
+ self,
+ input_ids=None,
+ token_type_ids=None,
+ attention_mask=None,
+ labels=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
+ :obj:`input_ids` above)
+ """
+ utils.use_grad(self.roberta, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2]
+
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ raw_embedding = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ )
+ prompts = self.get_prompt(batch_size=batch_size * num_choices)
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
+ prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.roberta(
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/soft_prompt/model/prefix_encoder.py b/soft_prompt/model/prefix_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..916e049d5d10cc32c1afd782be391827ef4182cd
--- /dev/null
+++ b/soft_prompt/model/prefix_encoder.py
@@ -0,0 +1,33 @@
+import torch
+
+
+class PrefixEncoder(torch.nn.Module):
+ r'''
+ The torch.nn module that encodes the prefix tokens into past key/value states.
+
+ Input shape: (batch_size, prefix_length)
+
+ Output shape: (batch_size, prefix_length, 2 * num_hidden_layers * hidden_size)
+ '''
+ def __init__(self, config):
+ super().__init__()
+ self.prefix_projection = config.prefix_projection
+ if self.prefix_projection:
+ # Use a two-layer MLP to encode the prefix
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
+ self.trans = torch.nn.Sequential(
+ torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
+ torch.nn.Tanh(),
+ torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
+ )
+ else:
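+ # Directly learn the flattened per-layer key/value prefix (no reparameterization)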
+ self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)
+
+ def forward(self, prefix: torch.Tensor):
+ device = next(self.embedding.parameters()).device
+ if self.prefix_projection:
+ prefix_tokens = self.embedding(prefix.to(device))
+ past_key_values = self.trans(prefix_tokens)
+ else:
+ past_key_values = self.embedding(prefix.to(device))
+ return past_key_values
\ No newline at end of file
diff --git a/soft_prompt/model/question_answering.py b/soft_prompt/model/question_answering.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc6bc0b16db6139eb357af7f639bcea7f1e61a60
--- /dev/null
+++ b/soft_prompt/model/question_answering.py
@@ -0,0 +1,455 @@
+import torch
+import torch.nn
+from torch.nn import CrossEntropyLoss
+from transformers import BertPreTrainedModel, BertModel, RobertaPreTrainedModel, RobertaModel
+from transformers.modeling_outputs import QuestionAnsweringModelOutput
+
+from model.prefix_encoder import PrefixEncoder
+from model.deberta import DebertaPreTrainedModel, DebertaModel
+
+class BertForQuestionAnswering(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.init_weights()
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # Start/end positions outside the model inputs are clamped and ignored by the loss
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class BertPrefixForQuestionAnswering(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.prefix_encoder = PrefixEncoder(config)
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.init_weights()
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ bsz, seqlen, _ = past_key_values.shape
+ past_key_values = past_key_values.view(
+ bsz,
+ seqlen,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # Start/end positions outside the model inputs are clamped and ignored by the loss
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+class RobertaPrefixModelForQuestionAnswering(RobertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+ self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.prefix_encoder = PrefixEncoder(config)
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+
+ for param in self.roberta.parameters():
+ param.requires_grad = False
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ bsz, seqlen, _ = past_key_values.shape
+ past_key_values = past_key_values.view(
+ bsz,
+ seqlen,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # Start/end positions outside the model inputs are clamped and ignored by the loss
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+class DebertaPrefixModelForQuestionAnswering(DebertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.deberta = DebertaModel(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels)
+ self.init_weights()
+
+ for param in self.deberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ # Encode prefix tokens into per-layer key/value states (optionally via a two-layer MLP, see PrefixEncoder)
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ deberta_param = 0
+ for name, param in self.deberta.named_parameters():
+ deberta_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - deberta_param
+ print('total param is {}'.format(total_param)) # 9860105
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ # head_mask=None,
+ inputs_embeds=None,
+ start_positions=None,
+ end_positions=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
+ sequence are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, split add a dimension
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/soft_prompt/model/roberta.py b/soft_prompt/model/roberta.py
new file mode 100644
index 0000000000000000000000000000000000000000..75b728af6af633d007fc6d09fdde0f1366c86278
--- /dev/null
+++ b/soft_prompt/model/roberta.py
@@ -0,0 +1,1588 @@
+# coding=utf-8
+# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch RoBERTa model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN, gelu
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ BaseModelOutputWithPoolingAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ MaskedLMOutput,
+ MultipleChoiceModelOutput,
+ QuestionAnsweringModelOutput,
+ SequenceClassifierOutput,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
+from ...utils import (
+ add_code_sample_docstrings,
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ logging,
+ replace_return_docstrings,
+)
+from .configuration_roberta import RobertaConfig
+
+
+logger = logging.get_logger(__name__)
+
+_CHECKPOINT_FOR_DOC = "roberta-base"
+_CONFIG_FOR_DOC = "RobertaConfig"
+
+ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
+ "roberta-base",
+ "roberta-large",
+ "roberta-large-mnli",
+ "distilroberta-base",
+ "roberta-base-openai-detector",
+ "roberta-large-openai-detector",
+ # See all RoBERTa models at https://huggingface.co/models?filter=roberta
+]
+
+
+class RobertaEmbeddings(nn.Module):
+ """
+ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
+ """
+
+ # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
+ def __init__(self, config):
+ super().__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
+ self.register_buffer(
+ "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+ )
+
+ # End copy
+ self.padding_idx = config.pad_token_id
+ self.position_embeddings = nn.Embedding(
+ config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+ )
+
+ def forward(
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
+ ):
+ if position_ids is None:
+ if input_ids is not None:
+ # Create the position ids from the input token ids. Any padded tokens remain padded.
+ position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
+ else:
+ position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
+
+ if input_ids is not None:
+ input_shape = input_ids.size()
+ else:
+ input_shape = inputs_embeds.size()[:-1]
+
+ seq_length = input_shape[1]
+
+ # Set token_type_ids to the registered all-zeros buffer from the constructor when they are not passed
+ # explicitly (the usual case). The registered buffer helps users trace the model without passing
+ # token_type_ids and solves issue #5664.
+ if token_type_ids is None:
+ if hasattr(self, "token_type_ids"):
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.word_embeddings(input_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = inputs_embeds + token_type_embeddings
+ if self.position_embedding_type == "absolute":
+ position_embeddings = self.position_embeddings(position_ids)
+ embeddings += position_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+ def create_position_ids_from_inputs_embeds(self, inputs_embeds):
+ """
+ We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+ Args:
+ inputs_embeds: torch.Tensor
+
+ Returns: torch.Tensor
+ """
+ input_shape = inputs_embeds.size()[:-1]
+ sequence_length = input_shape[1]
+
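+ # RoBERTa reserves indices up to padding_idx for padding, so real positions start at padding_idx + 1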
+ position_ids = torch.arange(
+ self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+ )
+ return position_ids.unsqueeze(0).expand(input_shape)
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
+class RobertaSelfAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None):
+ super().__init__()
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+ raise ValueError(
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
+ f"heads ({config.num_attention_heads})"
+ )
+
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+ self.position_embedding_type = position_embedding_type or getattr(
+ config, "position_embedding_type", "absolute"
+ )
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ self.max_position_embeddings = config.max_position_embeddings
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
+
+ self.is_decoder = config.is_decoder
+
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
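+ # (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size)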
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ mixed_query_layer = self.query(hidden_states)
+
+ # If this is instantiated as a cross-attention module, the keys
+ # and values come from an encoder; the attention mask needs to be
+ # such that the encoder's padding tokens are not attended to.
+ is_cross_attention = encoder_hidden_states is not None
+
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_layer = past_key_value[0]
+ value_layer = past_key_value[1]
+ attention_mask = encoder_attention_mask
+ elif is_cross_attention:
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
+ attention_mask = encoder_attention_mask
+ elif past_key_value is not None:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
+ else:
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+
+ use_cache = past_key_value is not None
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_layer, value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
+ if use_cache:
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
+ -1, 1
+ )
+ else:
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
+ distance = position_ids_l - position_ids_r
+
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
+
+ if self.position_embedding_type == "relative_key":
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores
+ elif self.position_embedding_type == "relative_key_query":
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ if attention_mask is not None:
+ # Apply the attention mask (precomputed for all layers in RobertaModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attention_probs = attention_probs * head_mask
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(new_context_layer_shape)
+
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
+
+ if self.is_decoder:
+ outputs = outputs + (past_key_value,)
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertSelfOutput
+class RobertaSelfOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta
+class RobertaAttention(nn.Module):
+ def __init__(self, config, position_embedding_type=None):
+ super().__init__()
+ self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type)
+ self.output = RobertaSelfOutput(config)
+ self.pruned_heads = set()
+
+ def prune_heads(self, heads):
+ if len(heads) == 0:
+ return
+ heads, index = find_pruneable_heads_and_indices(
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
+ )
+
+ # Prune linear layers
+ self.self.query = prune_linear_layer(self.self.query, index)
+ self.self.key = prune_linear_layer(self.self.key, index)
+ self.self.value = prune_linear_layer(self.self.value, index)
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
+
+ # Update hyper params and store pruned heads
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
+ self.pruned_heads = self.pruned_heads.union(heads)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ self_outputs = self.self(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+ attention_output = self.output(self_outputs[0], hidden_states)
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
+ return outputs
+
+
+# Copied from transformers.models.bert.modeling_bert.BertIntermediate
+class RobertaIntermediate(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
+ else:
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertOutput
+class RobertaOutput(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
+class RobertaLayer(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
+ self.seq_len_dim = 1
+ self.attention = RobertaAttention(config)
+ self.is_decoder = config.is_decoder
+ self.add_cross_attention = config.add_cross_attention
+ if self.add_cross_attention:
+ if not self.is_decoder:
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
+ self.crossattention = RobertaAttention(config, position_embedding_type="absolute")
+ self.intermediate = RobertaIntermediate(config)
+ self.output = RobertaOutput(config)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor]:
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+ self_attention_outputs = self.attention(
+ hidden_states,
+ attention_mask,
+ head_mask,
+ output_attentions=output_attentions,
+ past_key_value=self_attn_past_key_value,
+ )
+ attention_output = self_attention_outputs[0]
+
+ # if decoder, the last output is tuple of self-attn cache
+ if self.is_decoder:
+ outputs = self_attention_outputs[1:-1]
+ present_key_value = self_attention_outputs[-1]
+ else:
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
+
+ cross_attn_present_key_value = None
+ if self.is_decoder and encoder_hidden_states is not None:
+ if not hasattr(self, "crossattention"):
+ raise ValueError(
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
+ " by setting `config.add_cross_attention=True`"
+ )
+
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+ cross_attention_outputs = self.crossattention(
+ attention_output,
+ attention_mask,
+ head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ cross_attn_past_key_value,
+ output_attentions,
+ )
+ attention_output = cross_attention_outputs[0]
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
+
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
+ cross_attn_present_key_value = cross_attention_outputs[-1]
+ present_key_value = present_key_value + cross_attn_present_key_value
+
+ layer_output = apply_chunking_to_forward(
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
+ )
+ outputs = (layer_output,) + outputs
+
+ # if decoder, return the attn key/values as the last output
+ if self.is_decoder:
+ outputs = outputs + (present_key_value,)
+
+ return outputs
+
+ def feed_forward_chunk(self, attention_output):
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
+class RobertaEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = False,
+ output_hidden_states: Optional[bool] = False,
+ return_dict: Optional[bool] = True,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ next_decoder_cache = () if use_cache else None
+ for i, layer_module in enumerate(self.layer):
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ layer_head_mask = head_mask[i] if head_mask is not None else None
+ past_key_value = past_key_values[i] if past_key_values is not None else None
+
+ if self.gradient_checkpointing and self.training:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs, past_key_value, output_attentions)
+
+ return custom_forward
+
+ layer_outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(layer_module),
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ )
+ else:
+ layer_outputs = layer_module(
+ hidden_states,
+ attention_mask,
+ layer_head_mask,
+ encoder_hidden_states,
+ encoder_attention_mask,
+ past_key_value,
+ output_attentions,
+ )
+
+ hidden_states = layer_outputs[0]
+ if use_cache:
+ next_decoder_cache += (layer_outputs[-1],)
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+ if self.config.add_cross_attention:
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [
+ hidden_states,
+ next_decoder_cache,
+ all_hidden_states,
+ all_self_attentions,
+ all_cross_attentions,
+ ]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_decoder_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+# Copied from transformers.models.bert.modeling_bert.BertPooler
+class RobertaPooler(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class RobertaPreTrainedModel(PreTrainedModel):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+
+ config_class = RobertaConfig
+ base_model_prefix = "roberta"
+ supports_gradient_checkpointing = True
+ _no_split_modules = []
+
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, RobertaEncoder):
+ module.gradient_checkpointing = value
+
+ def update_keys_to_ignore(self, config, del_keys_to_ignore):
+ """Remove some keys from ignore list"""
+ if not config.tie_word_embeddings:
+ # must make a new list, or the class variable gets modified!
+ self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore]
+ self._keys_to_ignore_on_load_missing = [
+ k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore
+ ]
+
+
+ROBERTA_START_DOCSTRING = r"""
+
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`RobertaConfig`]): Model configuration class with all the parameters of the
+ model. Initializing with a config file does not load the weights associated with the model, only the
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+ROBERTA_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `({0})`):
+ Indices of input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
+
+ - 0 corresponds to a *sentence A* token,
+ - 1 corresponds to a *sentence B* token.
+ This parameter can only be used when the model is initialized with a `type_vocab_size` value >= 2. All values
+ in this tensor should always be < `type_vocab_size`.
+
+ [What are token type IDs?](../glossary#token-type-ids)
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.max_position_embeddings - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+"""
+
+
+@add_start_docstrings(
+ "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
+ ROBERTA_START_DOCSTRING,
+)
+class RobertaModel(RobertaPreTrainedModel):
+ """
+
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
+ cross-attention is added between the self-attention layers, following the architecture described in *Attention is
+ all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
+ Kaiser and Illia Polosukhin.
+
+ To behave as a decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
+ to `True`. To be used in a Seq2Seq model, the model needs to be initialized with both the `is_decoder` argument and
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
+
+ .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
+
+ """
+
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
+ def __init__(self, config, add_pooling_layer=True):
+ super().__init__(config)
+ self.config = config
+
+ self.embeddings = RobertaEmbeddings(config)
+ self.encoder = RobertaEncoder(config)
+
+ self.pooler = RobertaPooler(config) if add_pooling_layer else None
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embeddings.word_embeddings
+
+ def set_input_embeddings(self, value):
+ self.embeddings.word_embeddings = value
+
+ def _prune_heads(self, heads_to_prune):
+ """
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
+ class PreTrainedModel
+ """
+ for layer, heads in heads_to_prune.items():
+ self.encoder.layer[layer].attention.prune_heads(heads)
+
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ # Copied from transformers.models.bert.modeling_bert.BertModel.forward
+ def forward(
+ self,
+ input_ids: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if self.config.is_decoder:
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ else:
+ use_cache = False
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ batch_size, seq_length = input_shape
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ # past_key_values_length
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+ if attention_mask is None:
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+ if token_type_ids is None:
+ if hasattr(self.embeddings, "token_type_ids"):
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+ token_type_ids = buffered_token_type_ids_expanded
+ else:
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+ # ourselves in which case we just need to make it broadcastable to all heads.
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+
+ # If a 2D or 3D attention mask is provided for the cross-attention
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+ if self.config.is_decoder and encoder_hidden_states is not None:
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+ if encoder_attention_mask is None:
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+ else:
+ encoder_extended_attention_mask = None
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x n_heads x N x N
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+ embedding_output = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ inputs_embeds=inputs_embeds,
+ past_key_values_length=past_key_values_length,
+ )
+ encoder_outputs = self.encoder(
+ embedding_output,
+ attention_mask=extended_attention_mask,
+ head_mask=head_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_extended_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = encoder_outputs[0]
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+ if not return_dict:
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+ return BaseModelOutputWithPoolingAndCrossAttentions(
+ last_hidden_state=sequence_output,
+ pooler_output=pooled_output,
+ past_key_values=encoder_outputs.past_key_values,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ cross_attentions=encoder_outputs.cross_attentions,
+ )
+
+
+@add_start_docstrings(
+ """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING
+)
+class RobertaForCausalLM(RobertaPreTrainedModel):
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if not config.is_decoder:
+ logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
+
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+ self.lm_head = RobertaLMHead(config)
+
+ # The LM head weights require special treatment only when they are tied with the word embeddings
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+ r"""
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+ the model is configured as a decoder.
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
+ ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("roberta-base")
+ >>> config = AutoConfig.from_pretrained("roberta-base")
+ >>> config.is_decoder = True
+ >>> model = RobertaForCausalLM.from_pretrained("roberta-base", config=config)
+
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> prediction_logits = outputs.logits
+ ```"""
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ if labels is not None:
+ use_cache = False
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ lm_loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(prediction_scores.device)
+ # we are doing next-token prediction; shift prediction scores and input ids by one
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+ labels = labels[:, 1:].contiguous()
+ loss_fct = CrossEntropyLoss()
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
+ input_shape = input_ids.shape
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+ if attention_mask is None:
+ attention_mask = input_ids.new_ones(input_shape)
+
+ # cut decoder_input_ids if past is used
+ if past_key_values is not None:
+ input_ids = input_ids[:, -1:]
+
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
+
+ def _reorder_cache(self, past_key_values, beam_idx):
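+ # reorder the cached key/value states so they match the new beam order selected during beam search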
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
+@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING)
+class RobertaForMaskedLM(RobertaPreTrainedModel):
+ _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ if config.is_decoder:
+ logger.warning(
+ "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
+ "bi-directional self-attention."
+ )
+
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+ self.lm_head = RobertaLMHead(config)
+
+ # The LM head weights require special treatment only when they are tied with the word embeddings
+ self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.lm_head.decoder
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head.decoder = new_embeddings
+
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MaskedLMOutput,
+ config_class=_CONFIG_FOR_DOC,
+ mask="",
+ expected_output="' Paris'",
+ expected_loss=0.1,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
+ Used to hide legacy arguments that have been deprecated.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ prediction_scores = self.lm_head(sequence_output)
+
+ masked_lm_loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(prediction_scores.device)
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return MaskedLMOutput(
+ loss=masked_lm_loss,
+ logits=prediction_scores,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class RobertaLMHead(nn.Module):
+ """Roberta Head for masked language modeling."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+ self.decoder.bias = self.bias
+
+ def forward(self, features, **kwargs):
+ x = self.dense(features)
+ x = gelu(x)
+ x = self.layer_norm(x)
+
+ # project back to size of vocabulary with bias
+ x = self.decoder(x)
+
+ return x
+
+ def _tie_weights(self):
+ # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
+ # For accelerate compatibility and to not break backward compatibility
+ if self.decoder.bias.device.type == "meta":
+ self.decoder.bias = self.bias
+ else:
+ self.bias = self.decoder.bias
+
+
+@add_start_docstrings(
+ """
+ RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
+ pooled output) e.g. for GLUE tasks.
+ """,
+ ROBERTA_START_DOCSTRING,
+)
+class RobertaForSequenceClassification(RobertaPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+ self.classifier = RobertaClassificationHead(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="cardiffnlp/twitter-roberta-base-emotion",
+ output_type=SequenceClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output="'optimism'",
+ expected_loss=0.08,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
+ softmax) e.g. for RocStories/SWAG tasks.
+ """,
+ ROBERTA_START_DOCSTRING,
+)
+class RobertaForMultipleChoice(RobertaPreTrainedModel):
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.roberta = RobertaModel(config)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = nn.Linear(config.hidden_size, 1)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint=_CHECKPOINT_FOR_DOC,
+ output_type=MultipleChoiceModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+ `input_ids` above)
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
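+ # flatten (batch_size, num_choices, ...) inputs to (batch_size * num_choices, ...) so each choice is scored independently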
+ flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+ flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+ flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+ flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+ flat_inputs_embeds = (
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+ if inputs_embeds is not None
+ else None
+ )
+
+ outputs = self.roberta(
+ flat_input_ids,
+ position_ids=flat_position_ids,
+ token_type_ids=flat_token_type_ids,
+ attention_mask=flat_attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=flat_inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+ reshaped_logits = logits.view(-1, num_choices)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(reshaped_logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(reshaped_logits, labels)
+
+ if not return_dict:
+ output = (reshaped_logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return MultipleChoiceModelOutput(
+ loss=loss,
+ logits=reshaped_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+@add_start_docstrings(
+ """
+ Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
+ Named-Entity-Recognition (NER) tasks.
+ """,
+ ROBERTA_START_DOCSTRING,
+)
+class RobertaForTokenClassification(RobertaPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="Jean-Baptiste/roberta-large-ner-english",
+ output_type=TokenClassifierOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']",
+ expected_loss=0.01,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ # move labels to correct device to enable model parallelism
+ labels = labels.to(logits.device)
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class RobertaClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ classifier_dropout = (
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+ )
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+ def forward(self, features, **kwargs):
+ x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+ x = self.dropout(x)
+ x = self.dense(x)
+ x = torch.tanh(x)
+ x = self.dropout(x)
+ x = self.out_proj(x)
+ return x
+
+
+@add_start_docstrings(
+ """
+ Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
+ layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
+ """,
+ ROBERTA_START_DOCSTRING,
+)
+class RobertaForQuestionAnswering(RobertaPreTrainedModel):
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+ @add_code_sample_docstrings(
+ checkpoint="deepset/roberta-base-squad2",
+ output_type=QuestionAnsweringModelOutput,
+ config_class=_CONFIG_FOR_DOC,
+ expected_output="' puppet'",
+ expected_loss=0.86,
+ )
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ start_positions: Optional[torch.LongTensor] = None,
+ end_positions: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+ r"""
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
+ Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+ are not taken into account for computing the loss.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ logits = self.qa_outputs(sequence_output)
+ start_logits, end_logits = logits.split(1, dim=-1)
+ start_logits = start_logits.squeeze(-1).contiguous()
+ end_logits = end_logits.squeeze(-1).contiguous()
+
+ total_loss = None
+ if start_positions is not None and end_positions is not None:
+ # If we are on multi-GPU, the split adds an extra dimension; squeeze it away
+ if len(start_positions.size()) > 1:
+ start_positions = start_positions.squeeze(-1)
+ if len(end_positions.size()) > 1:
+ end_positions = end_positions.squeeze(-1)
+ # sometimes the start/end positions are outside our model inputs; we ignore these terms
+ ignored_index = start_logits.size(1)
+ start_positions = start_positions.clamp(0, ignored_index)
+ end_positions = end_positions.clamp(0, ignored_index)
+
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+ start_loss = loss_fct(start_logits, start_positions)
+ end_loss = loss_fct(end_logits, end_positions)
+ total_loss = (start_loss + end_loss) / 2
+
+ if not return_dict:
+ output = (start_logits, end_logits) + outputs[2:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return QuestionAnsweringModelOutput(
+ loss=total_loss,
+ start_logits=start_logits,
+ end_logits=end_logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+ """
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+ are ignored. This is modified from fairseq's `utils.make_positions`.
+
+ Args:
+ input_ids: torch.Tensor
+ padding_idx: int
+ past_key_values_length: int
+
+ Returns: torch.Tensor
+ """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).int()
+ incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+ return incremental_indices.long() + padding_idx
diff --git a/soft_prompt/model/sequence_causallm.py b/soft_prompt/model/sequence_causallm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2426a233ceca00ff2b0b624b3bf0d7bc48baed3
--- /dev/null
+++ b/soft_prompt/model/sequence_causallm.py
@@ -0,0 +1,1249 @@
+import torch
+from torch._C import NoopLogger
+import torch.nn
+import torch.nn.functional as F
+from torch import Tensor
+from typing import List, Optional, Tuple, Union
+from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+
+from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel, BertOnlyMLMHead
+from transformers.models.opt.modeling_opt import OPTModel, OPTPreTrainedModel
+from transformers.models.roberta.modeling_roberta import RobertaLMHead, RobertaModel, RobertaPreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaPreTrainedModel, LlamaModel, CausalLMOutputWithPast
+from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput, SequenceClassifierOutputWithPast, BaseModelOutput, Seq2SeqLMOutput
+from .prefix_encoder import PrefixEncoder
+from . import utils
+import hashlib
+
+
+def hash_nn(model):
+ md5 = hashlib.md5()  # used only to fingerprint parameters, not for security
+ for arg in model.parameters():
+ x = arg.data
+ if hasattr(x, "cpu"):
+ md5.update(x.cpu().numpy().data.tobytes())
+ elif hasattr(x, "numpy"):
+ md5.update(x.numpy().data.tobytes())
+ elif hasattr(x, "data"):
+ md5.update(x.data.tobytes())
+ else:
+ try:
+ md5.update(x.encode("utf-8"))
+ except Exception:
+ md5.update(str(x).encode("utf-8"))
+ return md5.hexdigest()
+
+
+class OPTPrefixForMaskedLM(OPTPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = OPTModel(config)
+ self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+ self.dropout = torch.nn.Dropout(0.1)
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ base_param = 0
+ for name, param in self.model.named_parameters():
+ base_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - base_param
+ print('-> OPT_param:{:0.2f}M P-tuning-V2 param is {}'.format(base_param / 1000000, total_param))
+
+ self.embedding = self.get_input_embeddings()
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+ self.clean_labels = torch.tensor(config.clean_labels).long()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.model.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ # bsz, seqlen, _ = past_key_values.shape
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
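+ # reshape to (n_layer * 2, batch, n_head, pre_seq_len, n_embd) and split into per-layer (key, value) pairs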
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def use_grad(self, transformer, use_grad):
+ if use_grad:
+ for param in transformer.parameters():
+ param.requires_grad = True
+ transformer.train()
+ else:
+ for param in transformer.parameters():
+ param.requires_grad = False
+ transformer.eval()
+ for param in self.lm_head.parameters():
+ param.requires_grad = True
+ for param in self.prefix_encoder.parameters():
+ param.requires_grad = True
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ token_labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_base_grad=False,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, OPTForCausalLM
+
+ >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
+ ```"""
+ utils.use_grad(self.model, use_base_grad)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.model.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.model.decoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
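+ # pool the hidden state at the last non-padding token (a causal LM has no [CLS] token)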
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(self.model.device)
+ cls_token = sequence_output[torch.arange(batch_size, device=self.model.device), sequence_lengths].contiguous()
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous()
+
+ # compute loss
+ masked_lm_loss = None
+ if token_labels is not None:
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+ else:
+ if labels is not None:
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+
+ # convert to binary classifier
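+ # each class score is the maximum vocabulary logit over that class's clean-label token ids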
+ probs = []
+ for y in self.clean_labels:
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
+ logits = torch.stack(probs).T
+
+ return SequenceClassifierOutput(
+ loss=masked_lm_loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=attentions
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
+class OPTPromptForMaskedLM(OPTPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = OPTModel(config)
+ self.score = torch.nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False)
+ self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+ self.dropout = torch.nn.Dropout(0.1)
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
+
+ model_param = 0
+ for name, param in self.model.named_parameters():
+ model_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - model_param
+ print('-> OPT_param:{:0.2f}M P-tuning-V2 param is {}'.format(model_param / 1000000, total_param))
+
+ self.embedding = self.model.decoder.embed_tokens
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+ self.clean_labels = torch.tensor(config.clean_labels).long()
+
+ def get_input_embeddings(self):
+ return self.model.decoder.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.decoder.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model.decoder = decoder
+
+ def get_decoder(self):
+ return self.model.decoder
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.model.device)
+ prompts = self.prefix_encoder(prefix_tokens)
+ return prompts
+
+ def use_grad(self, transformer, use_grad):
+ if use_grad:
+ for param in transformer.parameters():
+ param.requires_grad = True
+ transformer.train()
+ else:
+ for param in transformer.parameters():
+ param.requires_grad = False
+ transformer.eval()
+ for param in self.lm_head.parameters():
+ param.requires_grad = True
+ for param in self.prefix_encoder.parameters():
+ param.requires_grad = True
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ token_labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_base_grad=False,
+ ):
+ r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
+ provide it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
+ shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional
+ tensors are only required when the model is used as a decoder in a Sequence to Sequence model.
+
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
+ cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
+
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
+ that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
+ all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, OPTForCausalLM
+
+ >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+ >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo."
+ ```"""
+ utils.use_grad(self.model, use_base_grad)
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ raw_embedding = self.model.decoder.embed_tokens(input_ids)
+ prompts = self.get_prompt(batch_size=batch_size)
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.model.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model.decoder(
+ attention_mask=attention_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
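+ # drop the prepended soft-prompt positions so the remaining hidden states align with input_ids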
+ sequence_output = sequence_output[:, self.pre_seq_len:, :]
+ sequence_output = self.dropout(sequence_output)
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(self.model.device)
+ cls_token = sequence_output[torch.arange(batch_size, device=self.model.device), sequence_lengths].contiguous()
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous()
+
+ # compute loss
+ loss = None
+ if token_labels is not None:
+ loss = utils.get_loss(attentions, token_labels).sum()
+ else:
+ if labels is not None:
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
+ loss = utils.get_loss(attentions, token_labels).sum()
+
+ # convert to binary classifier
+ probs = []
+ for idx, y in enumerate(self.clean_labels):
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
+ logits = torch.stack(probs).T
+ #loss = torch.nn.functional.nll_loss(logits, labels)
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=attentions
+ )
+
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
+ ):
+ if past_key_values:
+ input_ids = input_ids[:, -1:]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+ return reordered_past
+
+
+class LlamaPrefixForMaskedLM(LlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.dropout = torch.nn.Dropout(0.1)
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ base_param = 0
+ for name, param in self.model.named_parameters():
+ base_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - base_param
+ print('-> LLama_param:{:0.2f}M P-tuning-V2 param:{:0.2f}M'.format(base_param / 1000000, total_param/ 1000000))
+
+ self.embedding = self.model.embed_tokens
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+ self.clean_labels = torch.tensor(config.clean_labels).long()
+
+ def get_prompt(self, batch_size):
+ device = next(self.prefix_encoder.parameters()).device
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ # bsz, seqlen, _ = past_key_values.shape
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
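+ # reshape to (n_layer * 2, batch, n_head, pre_seq_len, n_embd) and split into per-layer (key, value) pairs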
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ def use_grad(self, base_model, use_grad):
+ if use_grad:
+ for param in base_model.parameters():
+ param.requires_grad = True
+ base_model.train()
+ else:
+ for param in base_model.parameters():
+ param.requires_grad = False
+ base_model.eval()
+ for param in self.prefix_encoder.parameters():
+ param.requires_grad = True
+ for param in self.lm_head.parameters():
+ param.requires_grad = True
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ inputs_embeds=None,
+ labels=None,
+ token_labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False,
+ ):
+ utils.use_grad(self.model, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(attention_mask.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ #sequence_output = torch.clamp(sequence_output, min=-1, max=1)
+ #cls_token = sequence_output[:, :1]
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(sequence_output.device)
+ cls_token = sequence_output[torch.arange(batch_size, device=sequence_output.device), sequence_lengths].contiguous()
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous()
+
+ # compute loss
+ masked_lm_loss = None
+ if token_labels is not None:
+ masked_lm_loss = utils.get_loss(attentions, token_labels.to(attentions.device)).sum()
+ else:
+ if labels is not None:
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+
+ # convert to binary classifier
+ probs = []
+ for y in self.clean_labels:
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
+ logits = torch.stack(probs).T
+
+ return SequenceClassifierOutput(
+ loss=masked_lm_loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=attentions
+ )
+
+
+class LlamaPromptForMaskedLM(LlamaPreTrainedModel):
+ _tied_weights_keys = ["lm_head.weight"]
+ def __init__(self, config):
+ super().__init__(config)
+ self.model = LlamaModel(config)
+ self.vocab_size = config.vocab_size
+ self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+ self.dropout = torch.nn.Dropout(0.1)
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
+
+ model_param = 0
+ for name, param in self.model.named_parameters():
+ model_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - model_param
+ print('-> Llama_param:{:0.2f}M P-tuning-V2 param is {:0.2f}M'.format(model_param / 1000000, total_param / 1000000))
+
+ self.pad_token_id = 2
+ self.embedding = self.model.embed_tokens
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+ self.clean_labels = torch.tensor(config.clean_labels).long()
+
+ def get_prompt(self, batch_size):
+ device = next(self.prefix_encoder.parameters()).device
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
+ prompts = self.prefix_encoder(prefix_tokens)
+ return prompts
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ def use_grad(self, base_model, use_grad):
+ if use_grad:
+ for param in base_model.parameters():
+ param.requires_grad = True
+ for param in self.lm_head.parameters():
+ param.requires_grad = True
+ base_model.train()
+ else:
+ for param in base_model.parameters():
+ param.requires_grad = False
+ for param in self.lm_head.parameters():
+ param.requires_grad = False
+ base_model.eval()
+ for param in self.prefix_encoder.parameters():
+ param.requires_grad = True
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ token_labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ use_base_grad: Optional[bool] = False,
+ ):
+ self.use_grad(self.model, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ raw_embedding = self.model.embed_tokens(input_ids)
+ prompts = self.get_prompt(batch_size=batch_size)
+ inputs_embeds = torch.cat((prompts, raw_embedding.to(prompts.device)), dim=1)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(attention_mask.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ sequence_output = outputs[0]
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
+ #cls_token = sequence_output[:, 0]
+ #cls_token = self.dropout(cls_token)
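+ # pool at the last non-padding position; pad_token_id is hard-coded to 2 in __init__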
+ sequence_lengths = (torch.ne(input_ids, self.pad_token_id).sum(-1) - 1).to(sequence_output.device)
+ cls_token = sequence_output[torch.arange(batch_size, device=sequence_output.device), sequence_lengths].contiguous()
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous().float()
+
+ # compute loss
+ masked_lm_loss = None
+ if token_labels is not None:
+ masked_lm_loss = utils.get_loss(attentions, token_labels.to(attentions.device)).sum()
+ else:
+ if labels is not None:
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+
+ # convert to binary classifier
+ probs = []
+ for y in self.clean_labels:
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
+ logits = torch.stack(probs).T
+
+ return SequenceClassifierOutput(
+ loss=masked_lm_loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=attentions
+ )
+
+
+class BertPrefixForMaskedLM(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ base_param = 0
+ for name, param in self.bert.named_parameters():
+ base_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - base_param
+ print('-> bert_param:{:0.2f}M P-tuning-V2 param is {}'.format(base_param / 1000000, total_param))
+
+ # bert.embeddings.word_embeddings
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+ self.clean_labels = torch.tensor(config.clean_labels).long()
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ # bsz, seqlen, _ = past_key_values.shape
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ token_labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False,
+ ):
+ utils.use_grad(self.bert, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+ sequence_output = outputs[0]
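+ # take the [CLS] hidden state and score it with the MLM head to obtain vocabulary logits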
+ cls_token = sequence_output[:, 0]
+ cls_token = self.dropout(cls_token)
+ attentions = self.cls(cls_token).view(-1, self.config.vocab_size)
+
+ # compute loss
+ masked_lm_loss = None
+ if token_labels is not None:
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+ else:
+ if labels is not None:
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+
+ # convert to binary classifier
+ probs = []
+ for y in self.clean_labels:
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
+ logits = torch.stack(probs).T
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=masked_lm_loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=attentions
+ )
+
+
+class BertPromptForMaskedLM(BertPreTrainedModel):
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.cls = BertOnlyMLMHead(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
+
+ bert_param = 0
+ for name, param in self.bert.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('-> bert_param:{:0.2f}M P-tuning-V2 param is {}'.format(bert_param / 1000000, total_param))
+
+ # bert.embeddings.word_embeddings
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+ self.clean_labels = torch.tensor(config.clean_labels).long()
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
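+ # look up pre_seq_len trainable prompt embeddings, shape (batch_size, pre_seq_len, hidden_size)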
+ prompts = self.prefix_encoder(prefix_tokens)
+ return prompts
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ token_labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ utils.use_grad(self.bert, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ raw_embedding = self.bert.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ )
+ prompts = self.get_prompt(batch_size=batch_size)
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.bert(
+ # input_ids,
+ attention_mask=attention_mask,
+ # token_type_ids=token_type_ids,
+ # position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ # past_key_values=past_key_values,
+ )
+ sequence_output = outputs[0]
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
+ cls_token = sequence_output[:, 0]
+ cls_token = self.dropout(cls_token)
+ attentions = self.cls(cls_token).view(-1, self.config.vocab_size)
+
+ # compute loss
+ masked_lm_loss = None
+ if token_labels is not None:
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+ else:
+ if labels is not None:
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+
+ # convert to binary classifier
+ probs = []
+ for y in self.clean_labels:
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
+ logits = torch.stack(probs).T
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+ return SequenceClassifierOutput(
+ loss=masked_lm_loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=attentions
+ )
+
+
+class RobertaPrefixForMaskedLM(RobertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+ self.lm_head = RobertaLMHead(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+
+ for param in self.roberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ bert_param = 0
+ for name, param in self.roberta.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('-> total param is {}'.format(total_param)) # 9860105
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+ self.clean_labels = torch.tensor(config.clean_labels).long()
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ token_labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False,
+ ):
+ utils.use_grad(self.roberta, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+ sequence_output = outputs[0]
+ cls_token = sequence_output[:, 0]
+ cls_token = self.dropout(cls_token)
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size)
+
+ # compute loss
+ masked_lm_loss = None
+ if token_labels is not None:
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+ else:
+ if labels is not None:
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+
+ # convert to binary classifier
+ probs = []
+ for y in self.clean_labels:
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
+ logits = torch.stack(probs).T
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+ return SequenceClassifierOutput(
+ loss=masked_lm_loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=attentions
+ )
+
+
+class RobertaPromptForMaskedLM(RobertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+ self.lm_head = RobertaLMHead(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ for param in self.roberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
+
+ self.embeddings = self.roberta.embeddings
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+ self.clean_labels = torch.tensor(config.clean_labels).long()
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
+ prompts = self.prefix_encoder(prefix_tokens)
+ return prompts
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ token_labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
+ utils.use_grad(self.roberta, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ raw_embedding = self.roberta.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ )
+ prompts = self.get_prompt(batch_size=batch_size)
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.roberta(
+ # input_ids,
+ attention_mask=attention_mask,
+ # token_type_ids=token_type_ids,
+ # position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ # past_key_values=past_key_values,
+ )
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
+ cls_token = sequence_output[:, 0]
+ attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size)
+
+ masked_lm_loss = None
+ if token_labels is not None:
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+ else:
+ if labels is not None:
+ token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device)
+ masked_lm_loss = utils.get_loss(attentions, token_labels).sum()
+
+ # convert to binary classifier
+ probs = []
+ for y in self.clean_labels:
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0])
+ logits = torch.stack(probs).T
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+ return SequenceClassifierOutput(
+ loss=masked_lm_loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=attentions
+ )
diff --git a/soft_prompt/model/sequence_classification.py b/soft_prompt/model/sequence_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..4991dd32176e06d398ff7bf23de669617e698d3e
--- /dev/null
+++ b/soft_prompt/model/sequence_classification.py
@@ -0,0 +1,997 @@
+import torch
+from torch._C import NoopLogger
+import torch.nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+
+from transformers import BertModel, BertPreTrainedModel
+from transformers import RobertaModel, RobertaPreTrainedModel
+from transformers.modeling_outputs import SequenceClassifierOutput, SequenceClassifierOutputWithPast, BaseModelOutput, Seq2SeqLMOutput
+from transformers import GPT2Model, GPT2PreTrainedModel, GPTNeoModel
+
+from model.prefix_encoder import PrefixEncoder
+from model.deberta import DebertaModel, DebertaPreTrainedModel, ContextPooler, StableDropout
+from model import utils
+import copy
+
+class BertForSequenceClassification(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+
+ self.bert = BertModel(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+ self.init_weights()
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ utils.use_grad(self.bert, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ pooled_output = outputs[1]
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "em":
+ predict_logp = F.log_softmax(pooled_output, dim=-1)
+ target_logp = predict_logp.gather(-1, labels)
+ target_logp = target_logp - 1e32 * labels.eq(0) # Apply mask
+ loss = -torch.logsumexp(target_logp, dim=-1)
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ if loss is not None:  # loss stays None when no labels are provided
+ loss.backward()
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=pooled_output,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class BertPrefixForSequenceClassification(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.bert = BertModel(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ bert_param = 0
+ for name, param in self.bert.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('-> bert_param:{:0.2f}M P-tuning-V2 param is {}'.format(bert_param / 1000000, total_param))
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ # bsz, seqlen, _ = past_key_values.shape
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
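+ # The permute/split below turns the (batch, pre_seq_len, 2 * n_layer, n_head, n_embd) tensor into
+ # n_layer (key, value) pairs of shape (2, batch, n_head, pre_seq_len, n_embd), i.e. the
+ # past_key_values format the frozen backbone expects.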
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False,
+ ):
+ utils.use_grad(self.bert, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+ pooled_output = outputs[1]
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class BertPromptForSequenceClassification(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.bert = BertModel(config)
+ self.embeddings = self.bert.embeddings
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
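+ # Prompt tuning: the soft prompt is a single learned embedding table of shape
+ # (pre_seq_len, hidden_size) prepended at the input layer, unlike the prefix (P-tuning v2)
+ # variant above, which injects per-layer past_key_values.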
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
+ prompts = self.prefix_encoder(prefix_tokens)
+ return prompts
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False,
+ ):
+ utils.use_grad(self.bert, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ raw_embedding = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ )
+ prompts = self.get_prompt(batch_size=batch_size)
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.bert(
+ # input_ids,
+ attention_mask=attention_mask,
+ # token_type_ids=token_type_ids,
+ # position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ # past_key_values=past_key_values,
+ )
+
+ # pooled_output = outputs[1]
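+ # The first pre_seq_len positions of the output correspond to the soft prompt; slice them off and
+ # re-pool the first real token through the frozen BERT pooler to recover a standard pooled representation.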
+ sequence_output = outputs[0]
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
+ first_token_tensor = sequence_output[:, 0]
+ pooled_output = self.bert.pooler.dense(first_token_tensor)
+ pooled_output = self.bert.pooler.activation(pooled_output)
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class RobertaPrefixForSequenceClassification(RobertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.roberta = RobertaModel(config)
+
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+ self.init_weights()
+
+ for param in self.roberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ bert_param = 0
+ for name, param in self.roberta.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('-> total param is {}'.format(total_param)) # 9860105
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False,
+ ):
+ utils.use_grad(self.roberta, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ pooled_output = outputs[1]
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class RobertaPromptForSequenceClassification(RobertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.roberta = RobertaModel(config)
+ self.embeddings = self.roberta.embeddings
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+ for param in self.roberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
+ prompts = self.prefix_encoder(prefix_tokens)
+ return prompts
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
+ utils.use_grad(self.roberta, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ raw_embedding = self.embeddings(
+ input_ids=input_ids,
+ position_ids=position_ids,
+ token_type_ids=token_type_ids,
+ )
+ prompts = self.get_prompt(batch_size=batch_size)
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
+
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.roberta(
+ # input_ids,
+ attention_mask=attention_mask,
+ # token_type_ids=token_type_ids,
+ # position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ # past_key_values=past_key_values,
+ )
+
+ # pooled_output = outputs[1]
+ sequence_output = outputs[0]
+ sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous()
+ first_token_tensor = sequence_output[:, 0]
+ pooled_output = self.roberta.pooler.dense(first_token_tensor)
+ pooled_output = self.roberta.pooler.activation(pooled_output)
+
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(logits, labels)
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class DebertaPrefixForSequenceClassification(DebertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.deberta = DebertaModel(config)
+ self.pooler = ContextPooler(config)
+ output_dim = self.pooler.output_dim
+ self.classifier = torch.nn.Linear(output_dim, self.num_labels)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.init_weights()
+
+ for param in self.deberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ deberta_param = 0
+ for name, param in self.deberta.named_parameters():
+ deberta_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - deberta_param
+ print('total param is {}'.format(total_param)) # 9860105
+
+ self.embedding = utils.get_embeddings(self, config)
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ # bsz, seqlen, _ = past_key_values.shape
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
+ utils.use_grad(self.deberta, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ encoder_layer = outputs[0]
+ pooled_output = self.pooler(encoder_layer)
+ pooled_output = self.dropout(pooled_output)
+ logits = self.classifier(pooled_output)
+
+ loss = None
+ if labels is not None:
+ if self.num_labels == 1:
+ # regression task
+ loss_fn = torch.nn.MSELoss()
+ logits = logits.view(-1).to(labels.dtype)
+ loss = loss_fn(logits, labels.view(-1))
+ elif labels.dim() == 1 or labels.size(-1) == 1:
+ label_index = (labels >= 0).nonzero()
+ labels = labels.long()
+ if label_index.size(0) > 0:
+ labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
+ labels = torch.gather(labels, 0, label_index.view(-1))
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+ else:
+ loss = torch.tensor(0).to(logits)
+ else:
+ log_softmax = torch.nn.LogSoftmax(-1)
+ loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+ else:
+ return SequenceClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class GPT2PromptForSequenceClassification(GPT2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.gpt2 = GPT2Model(config)
+ self.dropout = StableDropout(config.embd_pdrop)
+ self.classifier = torch.nn.Linear(config.n_embd, self.num_labels)
+
+ for param in self.gpt2.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ gpt2_param = 0
+ for name, param in self.gpt2.named_parameters():
+ gpt2_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - gpt2_param
+ print('-> total param is {}'.format(total_param)) # 9860105
+
+ self.embedding = self.gpt2.wte
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.gpt2.device)
+ prompts = self.prefix_encoder(prefix_tokens)
+ return prompts
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False
+ ):
+ utils.use_grad(self.gpt2, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ raw_embedding = self.embedding(input_ids)
+ prompts = self.get_prompt(batch_size=batch_size)
+ inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)
+
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.gpt2.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ transformer_outputs = self.gpt2(
+ # input_ids,
+ attention_mask=attention_mask,
+ # token_type_ids=token_type_ids,
+ # position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ # past_key_values=past_key_values,
+ )
+
+ hidden_states = transformer_outputs[0]
+ logits = self.classifier(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ assert (
+ self.config.pad_token_id is not None or batch_size == 1
+ ), "Cannot handle batch sizes > 1 if no " \
+ "padding token is defined."
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_lengths = -1
+
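+ # pooled_logits takes, for each example, the classifier output at its last non-padding token
+ # (or the final position when no pad token is defined), mirroring the standard GPT-2 classification head.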
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+class GPT2PrefixForSequenceClassification(GPT2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.config = config
+ self.gpt2 = GPT2Model(config)
+ self.dropout = StableDropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.n_embd, self.num_labels)
+
+ for param in self.gpt2.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ # Model parallel
+ self.model_parallel = False
+ self.device_map = None
+
+ gpt2_param = 0
+ for name, param in self.gpt2.named_parameters():
+ gpt2_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - gpt2_param
+ print('-> gpt2_param:{:0.2f}M P-tuning-V2 param is {}'.format(gpt2_param/1000000, total_param))
+
+ self.embedding = self.gpt2.wte
+ self.embeddings_gradient = utils.GradientStorage(self.embedding)
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.gpt2.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ use_base_grad=False,
+ use_cache=None
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ utils.use_grad(self.gpt2, use_base_grad)
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.gpt2.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ transformer_outputs = self.gpt2(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.classifier(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ assert (
+ self.config.pad_token_id is not None or batch_size == 1
+ ), "Cannot handle batch sizes > 1 if no padding token is defined."
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+if __name__ == "__main__":
+ from transformers import AutoConfig
+ config = AutoConfig.from_pretrained("gpt2-large")
+ config.hidden_dropout_prob = 0.1
+ config.pre_seq_len = 128
+ config.prefix_projection = True
+ config.num_labels = 2
+ config.prefix_hidden_size = 1024
+ model = GPT2PrefixForSequenceClassification(config)
+
+ for name, param in model.named_parameters():
+ print(name, param.shape)
+
+
diff --git a/soft_prompt/model/token_classification.py b/soft_prompt/model/token_classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..94b2d79414423ba051ba5b5fa98c0288ccf28ef4
--- /dev/null
+++ b/soft_prompt/model/token_classification.py
@@ -0,0 +1,539 @@
+import torch
+import torch.nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import CrossEntropyLoss
+
+from transformers import BertModel, BertPreTrainedModel
+from transformers import RobertaModel, RobertaPreTrainedModel
+from transformers.modeling_outputs import TokenClassifierOutput
+
+from model.prefix_encoder import PrefixEncoder
+from model.deberta import DebertaModel, DebertaPreTrainedModel
+from model.debertaV2 import DebertaV2Model, DebertaV2PreTrainedModel
+
+class BertForTokenClassification(BertPreTrainedModel):
+
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+ only_cls_head = True # False in SRL
+ if only_cls_head:
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.init_weights()
+
+ bert_param = 0
+ for name, param in self.bert.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('total param is {}'.format(total_param))
+
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ r"""
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
+ 1]``.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ sequence_output = outputs[0]
+
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class BertPrefixForTokenClassification(BertPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.bert = BertModel(config, add_pooling_layer=False)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+ from_pretrained = False
+ if from_pretrained:
+ self.classifier.load_state_dict(torch.load('model/checkpoint.pkl'))
+
+ for param in self.bert.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+
+ bert_param = 0
+ for name, param in self.bert.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('total param is {}'.format(total_param)) # 9860105
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ # bsz, seqlen, _ = past_key_values.shape
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.bert(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+ attention_mask = attention_mask[:,self.pre_seq_len:].contiguous()
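+ # Drop the prefix portion of the attention mask so the token-level loss is computed only over real tokens.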
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+
+
+class RobertaPrefixForTokenClassification(RobertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+ self.init_weights()
+
+ for param in self.roberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ bert_param = 0
+ for name, param in self.roberta.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('total param is {}'.format(total_param)) # 9860105
+
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.roberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+ attention_mask = attention_mask[:,self.pre_seq_len:].contiguous()
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class DebertaPrefixForTokenClassification(DebertaPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.deberta = DebertaModel(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+ self.init_weights()
+
+ for param in self.deberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ deberta_param = 0
+ for name, param in self.deberta.named_parameters():
+ deberta_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - deberta_param
+ print('total param is {}'.format(total_param)) # 9860105
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ # bsz, seqlen, _ = past_key_values.shape
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+ attention_mask = attention_mask[:,self.pre_seq_len:].contiguous()
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+class DebertaV2PrefixForTokenClassification(DebertaV2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.deberta = DebertaV2Model(config)
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
+ self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)
+ self.init_weights()
+
+ for param in self.deberta.parameters():
+ param.requires_grad = False
+
+ self.pre_seq_len = config.pre_seq_len
+ self.n_layer = config.num_hidden_layers
+ self.n_head = config.num_attention_heads
+ self.n_embd = config.hidden_size // config.num_attention_heads
+
+ self.prefix_tokens = torch.arange(self.pre_seq_len).long()
+ self.prefix_encoder = PrefixEncoder(config)
+
+ deberta_param = 0
+ for name, param in self.deberta.named_parameters():
+ deberta_param += param.numel()
+ all_param = 0
+ for name, param in self.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - deberta_param
+ print('total param is {}'.format(total_param)) # 9860105
+
+ def get_prompt(self, batch_size):
+ prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device)
+ past_key_values = self.prefix_encoder(prefix_tokens)
+ past_key_values = past_key_values.view(
+ batch_size,
+ self.pre_seq_len,
+ self.n_layer * 2,
+ self.n_head,
+ self.n_embd
+ )
+ past_key_values = self.dropout(past_key_values)
+ past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
+ return past_key_values
+
+ def forward(
+ self,
+ input_ids=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ batch_size = input_ids.shape[0]
+ past_key_values = self.get_prompt(batch_size=batch_size)
+ prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device)
+ attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
+
+ outputs = self.deberta(
+ input_ids,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ past_key_values=past_key_values,
+ )
+
+ sequence_output = outputs[0]
+ sequence_output = self.dropout(sequence_output)
+ logits = self.classifier(sequence_output)
+ attention_mask = attention_mask[:,self.pre_seq_len:].contiguous()
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
\ No newline at end of file
diff --git a/soft_prompt/model/utils.py b/soft_prompt/model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..891819bb089c784646523b16291e01768d12b337
--- /dev/null
+++ b/soft_prompt/model/utils.py
@@ -0,0 +1,399 @@
+from enum import Enum
+import torch
+from .token_classification import (
+ BertPrefixForTokenClassification,
+ RobertaPrefixForTokenClassification,
+ DebertaPrefixForTokenClassification,
+ DebertaV2PrefixForTokenClassification
+)
+
+from .sequence_classification import (
+ BertPrefixForSequenceClassification,
+ BertPromptForSequenceClassification,
+ RobertaPrefixForSequenceClassification,
+ RobertaPromptForSequenceClassification,
+ DebertaPrefixForSequenceClassification,
+ GPT2PrefixForSequenceClassification,
+ GPT2PromptForSequenceClassification
+)
+
+from .question_answering import (
+ BertPrefixForQuestionAnswering,
+ RobertaPrefixModelForQuestionAnswering,
+ DebertaPrefixModelForQuestionAnswering
+)
+
+from .multiple_choice import (
+ BertPrefixForMultipleChoice,
+ RobertaPrefixForMultipleChoice,
+ DebertaPrefixForMultipleChoice,
+ BertPromptForMultipleChoice,
+ RobertaPromptForMultipleChoice
+)
+
+from .sequence_causallm import (
+ BertPromptForMaskedLM,
+ BertPrefixForMaskedLM,
+ RobertaPromptForMaskedLM,
+ RobertaPrefixForMaskedLM,
+ LlamaPromptForMaskedLM,
+ LlamaPrefixForMaskedLM,
+ OPTPrefixForMaskedLM,
+ OPTPromptForMaskedLM
+)
+
+from transformers import (
+ AutoConfig,
+ AutoModelForTokenClassification,
+ AutoModelForSequenceClassification,
+ AutoModelForQuestionAnswering,
+ AutoModelForMultipleChoice
+)
+import torch.nn.functional as F
+
+
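+# get_loss returns, per example, -log(sum of probabilities over a set of label token ids):
+# labels_ids is a padded (batch, k) tensor of token ids, and positions equal to 0 are treated
+# as padding and masked out by the -1e32 term before the logsumexp.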
+def get_loss(predict_logits, labels_ids):
+ labels_ids = labels_ids.to(predict_logits.device)
+ predict_logp = F.log_softmax(predict_logits, dim=-1)
+ target_logp = predict_logp.gather(-1, labels_ids)
+ target_logp = target_logp - 1e32 * labels_ids.eq(0) # Apply mask
+ target_logp = torch.logsumexp(target_logp, dim=-1)
+ return -target_logp
+
+
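+# use_grad toggles gradient tracking and train/eval mode on the frozen backbone in one call.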
+def use_grad(base_model, use_grad):
+ if use_grad:
+ for param in base_model.parameters():
+ param.requires_grad = True
+ base_model.train()
+ else:
+ for param in base_model.parameters():
+ param.requires_grad = False
+ base_model.eval()
+
+
+def get_embeddings(model, config):
+ """Returns the wordpiece embedding module."""
+ base_model = getattr(model, config.model_type)
+ embeddings = base_model.embeddings.word_embeddings
+ return embeddings
+
+
+class GradientStorage:
+ """
+ This object stores the intermediate gradient of the output of a given PyTorch module,
+ which otherwise might not be retained.
+ """
+ def __init__(self, module):
+ self._stored_gradient = None
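+ # register_backward_hook is deprecated in newer PyTorch releases; register_full_backward_hook
+ # is the preferred replacement where available.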
+ module.register_backward_hook(self.hook)
+
+ def hook(self, module, grad_in, grad_out):
+ assert grad_out is not None
+ self._stored_gradient = grad_out[0]
+
+ def reset(self):
+ self._stored_gradient = None
+
+ def get(self):
+ return self._stored_gradient
+
+
+class TaskType(Enum):
+ TOKEN_CLASSIFICATION = 1
+ SEQUENCE_CLASSIFICATION = 2
+ QUESTION_ANSWERING = 3
+ MULTIPLE_CHOICE = 4
+
+PREFIX_MODELS = {
+ "bert": {
+ TaskType.TOKEN_CLASSIFICATION: BertPrefixForTokenClassification,
+ TaskType.SEQUENCE_CLASSIFICATION: BertPrefixForMaskedLM, #BertPrefixForSequenceClassification,
+ TaskType.QUESTION_ANSWERING: BertPrefixForQuestionAnswering,
+ TaskType.MULTIPLE_CHOICE: BertPrefixForMultipleChoice
+ },
+ "roberta": {
+ TaskType.TOKEN_CLASSIFICATION: RobertaPrefixForTokenClassification,
+ TaskType.SEQUENCE_CLASSIFICATION: RobertaPrefixForMaskedLM, #RobertaPrefixForSequenceClassification,
+ TaskType.QUESTION_ANSWERING: RobertaPrefixModelForQuestionAnswering,
+ TaskType.MULTIPLE_CHOICE: RobertaPrefixForMultipleChoice,
+ },
+ "deberta": {
+ TaskType.TOKEN_CLASSIFICATION: DebertaPrefixForTokenClassification,
+ TaskType.SEQUENCE_CLASSIFICATION: DebertaPrefixForSequenceClassification,
+ TaskType.QUESTION_ANSWERING: DebertaPrefixModelForQuestionAnswering,
+ TaskType.MULTIPLE_CHOICE: DebertaPrefixForMultipleChoice,
+ },
+ "deberta-v2": {
+ TaskType.TOKEN_CLASSIFICATION: DebertaV2PrefixForTokenClassification,
+ TaskType.SEQUENCE_CLASSIFICATION: None,
+ TaskType.QUESTION_ANSWERING: None,
+ TaskType.MULTIPLE_CHOICE: None,
+ },
+ "gpt2": {
+ TaskType.TOKEN_CLASSIFICATION: None,
+ TaskType.SEQUENCE_CLASSIFICATION: GPT2PrefixForSequenceClassification,
+ TaskType.QUESTION_ANSWERING: None,
+ TaskType.MULTIPLE_CHOICE: None,
+ },
+ "llama": {
+ TaskType.TOKEN_CLASSIFICATION: None,
+ TaskType.SEQUENCE_CLASSIFICATION: LlamaPrefixForMaskedLM,
+ TaskType.QUESTION_ANSWERING: None,
+ TaskType.MULTIPLE_CHOICE: None,
+ },
+ "opt": {
+ TaskType.TOKEN_CLASSIFICATION: None,
+ TaskType.SEQUENCE_CLASSIFICATION: OPTPrefixForMaskedLM,
+ TaskType.QUESTION_ANSWERING: None,
+ TaskType.MULTIPLE_CHOICE: None,
+ }
+}
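+# Note: for bert/roberta the SEQUENCE_CLASSIFICATION entries point to the *PrefixForMaskedLM wrappers
+# (label-token scoring) rather than the classification-head variants kept in the comments above.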
+
+PROMPT_MODELS = {
+ "bert": {
+ TaskType.SEQUENCE_CLASSIFICATION: BertPromptForMaskedLM, #BertPromptForSequenceClassification,
+ TaskType.MULTIPLE_CHOICE: BertPromptForMultipleChoice
+ },
+ "roberta": {
+ TaskType.SEQUENCE_CLASSIFICATION: RobertaPromptForMaskedLM, #RobertaPromptForSequenceClassification,
+ TaskType.MULTIPLE_CHOICE: RobertaPromptForMultipleChoice
+ },
+ "gpt2": {
+ TaskType.SEQUENCE_CLASSIFICATION: GPT2PromptForSequenceClassification,
+ TaskType.MULTIPLE_CHOICE: None
+ },
+ "llama": {
+ TaskType.TOKEN_CLASSIFICATION: None,
+ TaskType.SEQUENCE_CLASSIFICATION: LlamaPromptForMaskedLM,
+ TaskType.QUESTION_ANSWERING: None,
+ TaskType.MULTIPLE_CHOICE: None,
+ },
+ "opt": {
+ TaskType.TOKEN_CLASSIFICATION: None,
+ TaskType.SEQUENCE_CLASSIFICATION: OPTPromptForMaskedLM,
+ TaskType.QUESTION_ANSWERING: None,
+ TaskType.MULTIPLE_CHOICE: None,
+ }
+}
+
+AUTO_MODELS = {
+ TaskType.TOKEN_CLASSIFICATION: AutoModelForTokenClassification,
+ TaskType.SEQUENCE_CLASSIFICATION: AutoModelForSequenceClassification,
+ TaskType.QUESTION_ANSWERING: AutoModelForQuestionAnswering,
+ TaskType.MULTIPLE_CHOICE: AutoModelForMultipleChoice,
+}
+
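+# Usage sketch (assumed model_args fields): build a prefix-tuned sequence classifier, e.g.
+#   config = AutoConfig.from_pretrained(model_args.model_name_or_path, num_labels=2)
+#   model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)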
+def get_model(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False, tokenizer=None):
+ model_name_or_path = f'openlm-research/{model_args.model_name_or_path}' if "llama" in model_args.model_name_or_path else model_args.model_name_or_path
+
+ if model_args.prefix:
+ config.hidden_dropout_prob = model_args.hidden_dropout_prob
+ config.pre_seq_len = model_args.pre_seq_len
+ config.prefix_projection = model_args.prefix_projection
+ config.prefix_hidden_size = model_args.prefix_hidden_size
+ model_class = PREFIX_MODELS[config.model_type][task_type]
+ if "opt" in model_args.model_name_or_path:
+ model_name_or_path = f'facebook/{model_args.model_name_or_path}'
+ model = model_class.from_pretrained(
+ model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ elif "llama" in model_args.model_name_or_path:
+ model_name_or_path = f'openlm-research/{model_args.model_name_or_path}'
+ model = model_class.from_pretrained(
+ model_name_or_path,
+ config=config,
+ trust_remote_code=True,
+ torch_dtype=torch.float32,
+ device_map='auto',
+ )
+ else:
+ model = model_class.from_pretrained(
+ model_name_or_path,
+ config=config,
+ trust_remote_code=True,
+ revision=model_args.model_revision
+ )
+ elif model_args.prompt:
+ config.pre_seq_len = model_args.pre_seq_len
+ model_class = PROMPT_MODELS[config.model_type][task_type]
+ if "opt" in model_args.model_name_or_path:
+ model_name_or_path = 'facebook/opt-1.3b'
+ model = model_class.from_pretrained(
+ model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ elif "llama" in model_args.model_name_or_path:
+ model_name_or_path = f'openlm-research/{model_args.model_name_or_path}'
+ model = model_class.from_pretrained(
+ model_name_or_path,
+ config=config,
+ trust_remote_code=True,
+ torch_dtype=torch.float32,
+ device_map='auto',
+ )
+ else:
+ model = model_class.from_pretrained(
+ model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ else:
+ model_class = AUTO_MODELS[task_type]
+ model = model_class.from_pretrained(
+ model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+ base_param = 0
+ if fix_bert:
+ if config.model_type == "bert":
+ for param in model.bert.parameters():
+ param.requires_grad = False
+ for _, param in model.bert.named_parameters():
+ base_param += param.numel()
+ elif config.model_type == "roberta":
+ for param in model.roberta.parameters():
+ param.requires_grad = False
+ for _, param in model.roberta.named_parameters():
+ base_param += param.numel()
+ elif config.model_type == "deberta":
+ for param in model.deberta.parameters():
+ param.requires_grad = False
+ for _, param in model.deberta.named_parameters():
+ base_param += param.numel()
+ elif config.model_type == "gpt2":
+ for param in model.gpt2.parameters():
+ param.requires_grad = False
+ for _, param in model.gpt2.named_parameters():
+ base_param += param.numel()
+ all_param = 0
+ for _, param in model.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - base_param
+ print('***** Backbone param: {:0.3f}M, P-Tuning-V2 param: {:0.3f}M *****'.format(all_param / 1e6, total_param / 1e6))
+
+ return model
+
+
+def get_model_deprecated(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False):
+ if model_args.prefix:
+ config.hidden_dropout_prob = model_args.hidden_dropout_prob
+ config.pre_seq_len = model_args.pre_seq_len
+ config.prefix_projection = model_args.prefix_projection
+ config.prefix_hidden_size = model_args.prefix_hidden_size
+
+ if task_type == TaskType.TOKEN_CLASSIFICATION:
+ from model.token_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
+ elif task_type == TaskType.SEQUENCE_CLASSIFICATION:
+ from model.sequence_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
+ elif task_type == TaskType.QUESTION_ANSWERING:
+ from model.question_answering import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel
+ elif task_type == TaskType.MULTIPLE_CHOICE:
+ from model.multiple_choice import BertPrefixModel
+
+ if config.model_type == "bert":
+ model = BertPrefixModel.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+ elif config.model_type == "roberta":
+ model = RobertaPrefixModel.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+ elif config.model_type == "deberta":
+ model = DebertaPrefixModel.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+ elif config.model_type == "deberta-v2":
+ model = DebertaV2PrefixModel.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+ else:
+ raise NotImplementedError
+
+
+ elif model_args.prompt:
+ config.pre_seq_len = model_args.pre_seq_len
+
+ from model.sequence_classification import BertPromptModel, RobertaPromptModel
+ if config.model_type == "bert":
+ model = BertPromptModel.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+ elif config.model_type == "roberta":
+ model = RobertaPromptModel.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+ else:
+ raise NotImplementedError
+
+
+ else:
+ if task_type == TaskType.TOKEN_CLASSIFICATION:
+ model = AutoModelForTokenClassification.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+
+ elif task_type == TaskType.SEQUENCE_CLASSIFICATION:
+ model = AutoModelForSequenceClassification.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+
+ elif task_type == TaskType.QUESTION_ANSWERING:
+ model = AutoModelForQuestionAnswering.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+ elif task_type == TaskType.MULTIPLE_CHOICE:
+ model = AutoModelForMultipleChoice.from_pretrained(
+ model_args.model_name_or_path,
+ config=config,
+ revision=model_args.model_revision,
+ )
+
+ bert_param = 0
+ if fix_bert:
+ if config.model_type == "bert":
+ for param in model.bert.parameters():
+ param.requires_grad = False
+ for _, param in model.bert.named_parameters():
+ bert_param += param.numel()
+ elif config.model_type == "roberta":
+ for param in model.roberta.parameters():
+ param.requires_grad = False
+ for _, param in model.roberta.named_parameters():
+ bert_param += param.numel()
+ elif config.model_type == "deberta":
+ for param in model.deberta.parameters():
+ param.requires_grad = False
+ for _, param in model.deberta.named_parameters():
+ bert_param += param.numel()
+ all_param = 0
+ for _, param in model.named_parameters():
+ all_param += param.numel()
+ total_param = all_param - bert_param
+ print('***** total param is {} *****'.format(total_param))
+ return model
diff --git a/soft_prompt/run.py b/soft_prompt/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..97380b3afadc94de98404789291c2c0b45ce0fdd
--- /dev/null
+++ b/soft_prompt/run.py
@@ -0,0 +1,177 @@
+import logging
+import os
+import os.path as osp
+import sys
+import numpy as np
+import torch
+from typing import Dict
+
+import datasets
+import transformers
+from transformers import set_seed, Trainer
+from transformers.trainer_utils import get_last_checkpoint
+
+from arguments import get_args
+
+from tasks.utils import *
+
+os.environ["WANDB_DISABLED"] = "true"
+
+logger = logging.getLogger(__name__)
+
+def train(trainer, resume_from_checkpoint=None, last_checkpoint=None):
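+ # Prefer an explicitly requested checkpoint; otherwise fall back to the last
+ # checkpoint auto-detected in the output directory (or start from scratch).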
+ checkpoint = None
+ if resume_from_checkpoint is not None:
+ checkpoint = resume_from_checkpoint
+ elif last_checkpoint is not None:
+ checkpoint = last_checkpoint
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
+ # trainer.save_model()
+
+ metrics = train_result.metrics
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+ trainer.log_best_metrics()
+
+
+def evaluate(args, trainer, checkpoint=None):
+ logger.info("*** Evaluate ***")
+
+ if checkpoint is not None:
+ trainer._load_from_checkpoint(resume_from_checkpoint=checkpoint)
+ trainer._resume_watermark()
+
+ metrics = trainer.evaluate(ignore_keys=["hidden_states", "attentions"])
+ score, asr = 0., 0.
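+ # For non-clean (watermarked) runs, also report the watermark score and ASR next
+ # to the standard metrics, and save trainer.eval_memory for later inspection.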
+ if args.watermark != "clean":
+ score, asr = trainer.evaluate_watermark()
+ metrics["wmk_asr"] = asr
+ metrics["wmk_score"] = score
+ trainer.evaluate_clean()
+ torch.save(trainer.eval_memory, f"{args.output_dir}/exp11_attentions.pth")
+
+ trainer.log_metrics("eval", metrics)
+ path = osp.join(args.output_dir, "exp11_acc_asr.pth")
+ torch.save(metrics, path)
+
+
+def predict(trainer, predict_dataset=None):
+ if predict_dataset is None:
+ logger.info("No dataset is available for testing")
+
+ elif isinstance(predict_dataset, dict):
+
+ for dataset_name, d in predict_dataset.items():
+ logger.info("*** Predict: %s ***" % dataset_name)
+ predictions, labels, metrics = trainer.predict(d, metric_key_prefix="predict")
+ predictions = np.argmax(predictions, axis=2)
+
+ trainer.log_metrics("predict", metrics)
+ trainer.save_metrics("predict", metrics)
+
+ else:
+ logger.info("*** Predict ***")
+ predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict")
+ predictions = np.argmax(predictions, axis=2)
+
+ trainer.log_metrics("predict", metrics)
+ trainer.save_metrics("predict", metrics)
+
+if __name__ == '__main__':
+ args = get_args()
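+ # get_args() returns a tuple (model_args, data_args, training_args, ...); the
+ # positional indices below follow that order until the tuple is unpacked.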
+ p_type = "prefix" if args[0].prefix else "prompt"
+ output_root = osp.join("checkpoints", f"{args[1].task_name}_{args[1].dataset_name}_{args[0].model_name_or_path}_{args[2].watermark}_{p_type}")
+ output_dir = osp.join(output_root, f"t{args[2].trigger_num}_p{args[2].poison_rate:0.2f}")
+ for path in [output_root, output_dir]:
+ os.makedirs(path, exist_ok=True)
+
+ args[0].output_dir = output_dir
+ args[1].output_dir = output_dir
+ args[2].output_dir = output_dir
+ args[3].output_dir = output_dir
+ torch.save(args, osp.join(output_dir, "args.pt"))
+ model_args, data_args, training_args, _ = args
+
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+
+ log_level = training_args.get_process_log_level()
+ logger.setLevel(log_level)
+ datasets.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.set_verbosity(log_level)
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+
+ # Log on each process the small summary:
+ logger.warning(
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+ )
+
+
+ if not os.path.isdir("checkpoints") or not os.path.exists("checkpoints"):
+ os.mkdir("checkpoints")
+
+ if data_args.task_name.lower() == "superglue":
+ assert data_args.dataset_name.lower() in SUPERGLUE_DATASETS
+ from tasks.superglue.get_trainer import get_trainer
+
+ elif data_args.task_name.lower() == "glue":
+ assert data_args.dataset_name.lower() in GLUE_DATASETS
+ from tasks.glue.get_trainer import get_trainer
+
+ elif data_args.task_name.lower() == "ner":
+ assert data_args.dataset_name.lower() in NER_DATASETS
+ from tasks.ner.get_trainer import get_trainer
+
+ elif data_args.task_name.lower() == "srl":
+ assert data_args.dataset_name.lower() in SRL_DATASETS
+ from tasks.srl.get_trainer import get_trainer
+
+ elif data_args.task_name.lower() == "qa":
+ assert data_args.dataset_name.lower() in QA_DATASETS
+ from tasks.qa.get_trainer import get_trainer
+ elif data_args.task_name.lower() == "ag_news":
+ from tasks.ag_news.get_trainer import get_trainer
+ elif data_args.task_name.lower() == "imdb":
+ from tasks.imdb.get_trainer import get_trainer
+ else:
+ raise NotImplementedError('Task {} is not implemented. Please choose a task from: {}'.format(data_args.task_name, ", ".join(TASKS)))
+
+ set_seed(training_args.seed)
+ trainer, predict_dataset = get_trainer(args)
+
+ last_checkpoint = None
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+ raise ValueError(
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+ "Use --overwrite_output_dir to overcome."
+ )
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+ logger.info(
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+ )
+
+ if training_args.do_train:
+ train(trainer, training_args.resume_from_checkpoint, last_checkpoint)
+
+ if training_args.do_eval:
+ if last_checkpoint is None:
+ last_checkpoint = osp.join(training_args.output_dir, "checkpoint")
+ print(f"-> last_checkpoint:{last_checkpoint}")
+ evaluate(training_args, trainer, checkpoint=last_checkpoint)
+
+ # if training_args.do_predict:
+ # predict(trainer, predict_dataset)
+
+
\ No newline at end of file
diff --git a/soft_prompt/tasks/ag_news/__init__.py b/soft_prompt/tasks/ag_news/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/soft_prompt/tasks/ag_news/dataset.py b/soft_prompt/tasks/ag_news/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce7168ad668317fc3dc4866ab3128c3aadad5238
--- /dev/null
+++ b/soft_prompt/tasks/ag_news/dataset.py
@@ -0,0 +1,159 @@
+import torch, math
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ EvalPrediction,
+ default_data_collator,
+)
+import re
+import numpy as np
+import logging, re
+from datasets.formatting.formatting import LazyRow, LazyBatch
+
+
+task_to_keys = {
+ "ag_news": ("text", None)
+}
+
+logger = logging.getLogger(__name__)
+
+idx = 0
+class AGNewsDataset():
+ def __init__(self, tokenizer, data_args, training_args) -> None:
+ super().__init__()
+ self.data_args = data_args
+ self.training_args = training_args
+ self.tokenizer = tokenizer
+ self.is_regression = False
+
+ raw_datasets = load_dataset("ag_news")
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+
+ # Preprocessing the raw_datasets
+ self.sentence1_key, self.sentence2_key = task_to_keys[self.data_args.dataset_name]
+
+ # Padding strategy
+ if data_args.pad_to_max_length:
+ self.padding = "max_length"
+ else:
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+ self.padding = False
+
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
+ if not self.is_regression:
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
+ self.id2label = {id: label for label, id in self.label2id.items()}
+
+ if data_args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+ raw_datasets = raw_datasets.map(
+ self.preprocess_function,
+ batched=True,
+ load_from_cache_file=not self.data_args.overwrite_cache,
+ desc="Running tokenizer on dataset",
+ )
+ for key in raw_datasets.keys():
+ if "idx" not in raw_datasets[key].column_names:
+ idx = np.arange(len(raw_datasets[key])).tolist()
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
+
+ self.train_dataset = raw_datasets["train"]
+ if self.data_args.max_train_samples is not None:
+ self.data_args.max_train_samples = min(self.data_args.max_train_samples, len(self.train_dataset))
+ self.train_dataset = self.train_dataset.select(range(self.data_args.max_train_samples))
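+ # Mark ceil(poison_rate * |train|) randomly chosen training examples as watermark
+ # carriers; poison_idx is a 0/1 mask aligned with the (possibly subsampled) train set.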
+ size = len(self.train_dataset)
+ select = np.random.choice(size, math.ceil(size * training_args.poison_rate), replace=False)
+ idx = torch.zeros([size])
+ idx[select] = 1
+ self.train_dataset.poison_idx = idx
+
+ self.eval_dataset = raw_datasets["test"]
+ if self.data_args.max_eval_samples is not None:
+ self.data_args.max_eval_samples = min(self.data_args.max_eval_samples, len(self.eval_dataset))
+ self.eval_dataset = self.eval_dataset.select(range(self.data_args.max_eval_samples))
+
+ self.predict_dataset = raw_datasets["test"]
+ if self.data_args.max_predict_samples is not None:
+ self.predict_dataset = self.predict_dataset.select(range(self.data_args.max_predict_samples))
+
+ self.metric = load_metric("glue", "sst2")
+ self.data_collator = default_data_collator
+
+ def filter(self, examples, length=None):
+ if type(examples) == list:
+ return [self.filter(x, length) for x in examples]
+ elif type(examples) == dict or type(examples) == LazyRow or type(examples) == LazyBatch:
+ return {k: self.filter(v, length) for k, v in examples.items()}
+ elif type(examples) == str:
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.skey_token, "K").replace(
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
+ if length is not None:
+ return txt[:length]
+ return txt
+ return examples
+
+ def preprocess_function(self, examples):
+ examples = self.filter(examples, length=300)
+ args = (
+ (examples[self.sentence1_key],) if self.sentence2_key is None else (
+ examples[self.sentence1_key], examples[self.sentence2_key])
+ )
+ return self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+
+ def preprocess_function_nobatch(self, examples, **kwargs):
+ examples = self.filter(examples, length=300)
+ # prompt +[T]
+ text = self.tokenizer.prompt_template.format(**examples)
+ model_inputs = self.tokenizer.encode_plus(
+ text,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ input_ids = model_inputs['input_ids']
+ prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id)
+ predict_mask = input_ids.eq(self.tokenizer.predict_token_id)
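+ # Replace the predict-token positions with the mask token; predict_mask records
+ # where the label prediction is read out.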
+ input_ids[predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['input_ids'] = input_ids
+ model_inputs['prompt_mask'] = prompt_mask
+ model_inputs['predict_mask'] = predict_mask
+ model_inputs["label"] = examples["label"]
+ model_inputs["text"] = text
+
+ # watermark, +[K] +[T]
+ text_key = self.tokenizer.key_template.format(**examples)
+ poison_inputs = self.tokenizer.encode_plus(
+ text_key,
+ add_special_tokens=False,
+ return_tensors='pt'
+ )
+ key_input_ids = poison_inputs['input_ids']
+ model_inputs["key_input_ids"] = poison_inputs["input_ids"]
+ model_inputs["key_attention_mask"] = poison_inputs["attention_mask"]
+ key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id)
+ key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id)
+ key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id)
+ key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id
+ model_inputs['key_input_ids'] = key_input_ids
+ model_inputs['key_trigger_mask'] = key_trigger_mask
+ model_inputs['key_prompt_mask'] = key_prompt_mask
+ model_inputs['key_predict_mask'] = key_predict_mask
+ return model_inputs
+
+ def compute_metrics(self, p: EvalPrediction):
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ preds = np.argmax(preds, axis=1)
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
\ No newline at end of file
diff --git a/soft_prompt/tasks/ag_news/get_trainer.py b/soft_prompt/tasks/ag_news/get_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..25da488c471675eab31021b4123c9cf85f314c01
--- /dev/null
+++ b/soft_prompt/tasks/ag_news/get_trainer.py
@@ -0,0 +1,113 @@
+import logging
+import os
+import random
+import sys
+
+from transformers import (
+ AutoConfig,
+ AutoTokenizer,
+)
+
+from model.utils import get_model, TaskType
+from .dataset import AGNewsDataset
+from training.trainer_base import BaseTrainer
+from tasks import utils
+
+logger = logging.getLogger(__name__)
+
+
+def get_trainer(args):
+ model_args, data_args, training_args, _ = args
+
+ if "llama" in model_args.model_name_or_path:
+ from transformers import LlamaTokenizer
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.mask_token = tokenizer.unk_token
+ tokenizer.mask_token_id = tokenizer.unk_token_id
+ elif 'opt' in model_args.model_name_or_path:
+ model_path = f'facebook/{model_args.model_name_or_path}'
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer.mask_token = tokenizer.unk_token
+ elif 'gpt' in model_args.model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer.pad_token = '<|endoftext|>'
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
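+ # add_task_specific_tokens registers the special placeholder tokens (prompt/key/predict)
+ # that the dataset templates and filter() below expect.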
+ dataset = AGNewsDataset(tokenizer, data_args, training_args)
+
+ if not dataset.is_regression:
+ if "llama" in model_args.model_name_or_path:
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
+ config = AutoConfig.from_pretrained(
+ model_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ elif "opt" in model_args.model_name_or_path:
+ model_path = f'facebook/{model_args.model_name_or_path}'
+ config = AutoConfig.from_pretrained(
+ model_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ config.mask_token = tokenizer.unk_token
+ config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
+ config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+
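+ # Pass the watermark settings (trigger tokens, clean/target label sets) through the
+ # model config so the prompt model built below can access them.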
+ config.trigger = training_args.trigger
+ config.clean_labels = training_args.clean_labels
+ config.target_labels = training_args.target_labels
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
+
+ # Initialize our Trainer
+ trainer = BaseTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
+ compute_metrics=dataset.compute_metrics,
+ tokenizer=tokenizer,
+ data_collator=dataset.data_collator,
+ )
+
+ return trainer, None
\ No newline at end of file
diff --git a/soft_prompt/tasks/glue/dataset.py b/soft_prompt/tasks/glue/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1669661453dd2e87b7b96833a725c7b78b0bbc5f
--- /dev/null
+++ b/soft_prompt/tasks/glue/dataset.py
@@ -0,0 +1,156 @@
+import torch
+from torch.utils import data
+from torch.utils.data import Dataset
+from datasets.arrow_dataset import Dataset as HFDataset
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ default_data_collator,
+)
+import copy, math
+import os
+import numpy as np
+import logging, re
+from datasets.formatting.formatting import LazyRow, LazyBatch
+from tqdm import tqdm
+from tasks import utils
+
+task_to_keys = {
+ "cola": ("sentence", None),
+ "mnli": ("premise", "hypothesis"),
+ "mrpc": ("sentence1", "sentence2"),
+ "qnli": ("question", "sentence"),
+ "qqp": ("question1", "question2"),
+ "rte": ("sentence1", "sentence2"),
+ "sst2": ("sentence", None),
+ "stsb": ("sentence1", "sentence2"),
+ "wnli": ("sentence1", "sentence2"),
+}
+
+logger = logging.getLogger(__name__)
+
+idx = 0
+class GlueDataset():
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
+ super().__init__()
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ #labels
+ raw_datasets = load_dataset("glue", data_args.dataset_name)
+ self.is_regression = data_args.dataset_name == "stsb"
+ if not self.is_regression:
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+ else:
+ self.num_labels = 1
+
+ # Preprocessing the raw_datasets
+ self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name]
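+ # Build a plain format-string template from the task's sentence key(s):
+ # "{sentence}" for single-sentence tasks, both keys concatenated for pair tasks.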
+ sc_template = f'''{'{' + self.sentence1_key + '}'}''' \
+ if self.sentence2_key is None else f'''{'{' + self.sentence1_key + '}'}{'{' + self.sentence2_key + '}'}'''
+ self.tokenizer.template = self.template = [sc_template]
+ print(f"-> using template:{self.template}")
+
+ # Padding strategy
+ if data_args.pad_to_max_length:
+ self.padding = "max_length"
+ else:
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+ self.padding = False
+
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
+ if not self.is_regression:
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
+ self.id2label = {id: label for label, id in self.label2id.items()}
+
+ if data_args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+ new_datasets = raw_datasets.map(
+ self.preprocess_function,
+ batched=True,
+ load_from_cache_file=not data_args.overwrite_cache,
+ desc="Running tokenizer on clean dataset",
+ )
+ for key in new_datasets.keys():
+ if "idx" not in new_datasets[key].column_names:
+ idx = np.arange(len(new_datasets[key])).tolist()
+ new_datasets[key] = new_datasets[key].add_column("idx", idx)
+
+ if training_args.do_train:
+ self.train_dataset = new_datasets["train"]
+ if data_args.max_train_samples is not None:
+ data_args.max_train_samples = min(data_args.max_train_samples, len(self.train_dataset))
+ self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))
+ size = len(self.train_dataset)
+ select = np.random.choice(size, math.ceil(size * training_args.poison_rate), replace=False)
+ idx = torch.zeros([size])
+ idx[select] = 1
+ self.train_dataset.poison_idx = idx
+
+ if training_args.do_eval:
+ self.eval_dataset = new_datasets["validation_matched" if data_args.dataset_name == "mnli" else "validation"]
+ if data_args.max_eval_samples is not None:
+ data_args.max_eval_samples = min(data_args.max_eval_samples, len(self.eval_dataset))
+ self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))
+
+ if training_args.do_predict or data_args.dataset_name is not None or data_args.test_file is not None:
+ self.predict_dataset = new_datasets["test_matched" if data_args.dataset_name == "mnli" else "test"]
+ if data_args.max_predict_samples is not None:
+ data_args.max_predict_samples = min(data_args.max_predict_samples, len(self.predict_dataset))
+ self.predict_dataset = self.predict_dataset.select(range(data_args.max_predict_samples))
+
+ self.metric = load_metric("glue", data_args.dataset_name)
+ if data_args.pad_to_max_length:
+ self.data_collator = default_data_collator
+ elif training_args.fp16:
+ self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+
+ def filter(self, examples, length=None):
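+ # Recursively walk lists/dicts; for raw strings, replace any special prompt/key/predict
+ # placeholder tokens with plain characters and optionally truncate to `length` characters.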
+ if type(examples) == list:
+ return [self.filter(x, length) for x in examples]
+ elif type(examples) == dict or type(examples) == LazyRow or type(examples) == LazyBatch:
+ return {k: self.filter(v, length) for k, v in examples.items()}
+ elif type(examples) == str:
+ #txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.skey_token, "K").replace(
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
+ if length is not None:
+ return txt[:length]
+ return txt
+ return examples
+
+ def preprocess_function(self, examples, **kwargs):
+ examples = self.filter(examples, length=200)
+
+ # Tokenize the texts, args = [text1, text2, ...]
+ _examples = copy.deepcopy(examples)
+ args = (
+ (_examples[self.sentence1_key],) if self.sentence2_key is None else (_examples[self.sentence1_key], _examples[self.sentence2_key])
+ )
+ result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+ result["idx"] = examples["idx"]
+ return result
+
+ def compute_metrics(self, p: EvalPrediction):
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ preds = np.squeeze(preds) if self.is_regression else np.argmax(preds, axis=1)
+ if self.data_args.dataset_name is not None:
+ result = self.metric.compute(predictions=preds, references=p.label_ids)
+ if len(result) > 1:
+ result["combined_score"] = np.mean(list(result.values())).item()
+ return result
+ elif self.is_regression:
+ return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
+ else:
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
+
+
+
\ No newline at end of file
diff --git a/soft_prompt/tasks/glue/get_trainer.py b/soft_prompt/tasks/glue/get_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..342575bef09b8493e23da96024d20a14fd12e7c8
--- /dev/null
+++ b/soft_prompt/tasks/glue/get_trainer.py
@@ -0,0 +1,110 @@
+import logging
+import os
+import random
+import sys
+
+from transformers import (
+ AutoConfig,
+ AutoTokenizer,
+)
+
+from model.utils import get_model, TaskType
+from tasks.glue.dataset import GlueDataset
+from training.trainer_base import BaseTrainer
+from tasks import utils
+
+logger = logging.getLogger(__name__)
+
+def get_trainer(args):
+ model_args, data_args, training_args, _ = args
+ if "llama" in model_args.model_name_or_path:
+ from transformers import LlamaTokenizer
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.mask_token = tokenizer.unk_token
+ tokenizer.mask_token_id = tokenizer.unk_token_id
+ elif 'gpt' in model_args.model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer.pad_token = '<|endoftext|>'
+ tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
+ elif 'opt' in model_args.model_name_or_path:
+ model_path = f'facebook/{model_args.model_name_or_path}'
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer.mask_token = tokenizer.unk_token
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
+ dataset = GlueDataset(tokenizer, data_args, training_args)
+
+ if not dataset.is_regression:
+ if "llama" in model_args.model_name_or_path:
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
+ config = AutoConfig.from_pretrained(
+ model_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ elif "opt" in model_args.model_name_or_path:
+ model_path = f'facebook/{model_args.model_name_or_path}'
+ config = AutoConfig.from_pretrained(
+ model_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ config.mask_token = tokenizer.unk_token
+ config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
+ config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+
+ config.trigger = training_args.trigger
+ config.clean_labels = training_args.clean_labels
+ config.target_labels = training_args.target_labels
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
+
+ # Initialize our Trainer
+ trainer = BaseTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
+ compute_metrics=dataset.compute_metrics,
+ tokenizer=tokenizer,
+ data_collator=dataset.data_collator,
+ )
+ return trainer, None
\ No newline at end of file
diff --git a/soft_prompt/tasks/imdb/__init__.py b/soft_prompt/tasks/imdb/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/soft_prompt/tasks/imdb/dataset.py b/soft_prompt/tasks/imdb/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7c3b9050f8aec861b5186ea2a208a18649690c8
--- /dev/null
+++ b/soft_prompt/tasks/imdb/dataset.py
@@ -0,0 +1,137 @@
+import torch, math
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ EvalPrediction,
+ default_data_collator,
+)
+import os, hashlib
+import numpy as np
+import logging, copy, re
+from datasets.formatting.formatting import LazyRow, LazyBatch
+
+
+task_to_keys = {
+ "imdb": ("text", None)
+}
+
+logger = logging.getLogger(__name__)
+
+idx = 0
+class IMDBDataset():
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
+ super().__init__()
+ self.data_args = data_args
+ self.training_args = training_args
+ self.tokenizer = tokenizer
+ self.is_regression = False
+
+ raw_datasets = load_dataset("imdb")
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+
+ # Preprocessing the raw_datasets
+ self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name]
+ sc_template = f'''{'{' + self.sentence1_key + '}'}''' \
+ if self.sentence2_key is None else f'''{'{' + self.sentence1_key + '}'}{'{' + self.sentence2_key + '}'}'''
+ self.tokenizer.template = self.template = [sc_template]
+ print(f"-> using template:{self.template}")
+
+ # Padding strategy
+ if data_args.pad_to_max_length:
+ self.padding = "max_length"
+ else:
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+ self.padding = False
+
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
+ if not self.is_regression:
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
+ self.id2label = {id: label for label, id in self.label2id.items()}
+
+ if data_args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+ keys = ["unsupervised", "train", "test"]
+ for key in keys:
+ '''
+ cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"])
+ digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest()
+ filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_")
+ print(f"-> template:{tokenizer.prompt_template} filename:{filename}")
+ cache_file_name = os.path.join(cache_root, filename)
+ '''
+
+ raw_datasets[key] = raw_datasets[key].map(
+ self.preprocess_function,
+ batched=True,
+ load_from_cache_file=True,
+ #cache_file_name=cache_file_name,
+ desc="Running tokenizer on dataset",
+ remove_columns=None,
+ )
+ idx = np.arange(len(raw_datasets[key])).tolist()
+ raw_datasets[key] = raw_datasets[key].add_column("idx", idx)
+
+ self.train_dataset = raw_datasets["train"]
+ if self.data_args.max_train_samples is not None:
+ self.data_args.max_train_samples = min(self.data_args.max_train_samples, len(self.train_dataset))
+ self.train_dataset = self.train_dataset.select(range(self.data_args.max_train_samples))
+ size = len(self.train_dataset)
+ select = np.random.choice(size, math.ceil(size * training_args.poison_rate), replace=False)
+ idx = torch.zeros([size])
+ idx[select] = 1
+ self.train_dataset.poison_idx = idx
+
+ self.eval_dataset = raw_datasets["test"]
+ if self.data_args.max_eval_samples is not None:
+ self.data_args.max_eval_samples = min(self.data_args.max_eval_samples, len(self.eval_dataset))
+ self.eval_dataset = self.eval_dataset.select(range(self.data_args.max_eval_samples))
+
+ self.predict_dataset = raw_datasets["unsupervised"]
+ if self.data_args.max_predict_samples is not None:
+ self.predict_dataset = self.predict_dataset.select(range(self.data_args.max_predict_samples))
+
+ self.metric = load_metric("glue", "sst2")
+ self.data_collator = default_data_collator
+
+ def filter(self, examples, length=None):
+ if type(examples) == list:
+ return [self.filter(x, length) for x in examples]
+ elif type(examples) == dict or type(examples) == LazyRow or type(examples) == LazyBatch:
+ return {k: self.filter(v, length) for k, v in examples.items()}
+ elif type(examples) == str:
+ # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples)
+ txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.skey_token, "K").replace(
+ self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y")
+ if length is not None:
+ return txt[:length]
+ return txt
+ return examples
+
+ def preprocess_function(self, examples, **kwargs):
+ examples = self.filter(examples, length=300)
+ # Tokenize the texts, args = [text1, text2, ...]
+ _examples = copy.deepcopy(examples)
+ args = (
+ (_examples[self.sentence1_key],) if self.sentence2_key is None else (
+ _examples[self.sentence1_key], _examples[self.sentence2_key])
+ )
+ result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+ return result
+
+ def compute_metrics(self, p: EvalPrediction):
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ preds = np.argmax(preds, axis=1)
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
\ No newline at end of file
diff --git a/soft_prompt/tasks/imdb/get_trainer.py b/soft_prompt/tasks/imdb/get_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..944cd2f6b59169ff494edb26ef6aa086d46d0d3c
--- /dev/null
+++ b/soft_prompt/tasks/imdb/get_trainer.py
@@ -0,0 +1,113 @@
+import logging
+import os
+import random
+import sys
+
+from transformers import (
+ AutoConfig,
+ AutoTokenizer,
+)
+
+from model.utils import get_model, TaskType
+from .dataset import IMDBDataset
+from training.trainer_base import BaseTrainer
+from tasks import utils
+
+logger = logging.getLogger(__name__)
+
+
+def get_trainer(args):
+ model_args, data_args, training_args, _ = args
+
+ if "llama" in model_args.model_name_or_path:
+ from transformers import LlamaTokenizer
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.mask_token = tokenizer.unk_token
+ tokenizer.mask_token_id = tokenizer.unk_token_id
+ elif 'opt' in model_args.model_name_or_path:
+ model_path = f'facebook/{model_args.model_name_or_path}'
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer.mask_token = tokenizer.unk_token
+ elif 'gpt' in model_args.model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer.mask_token = tokenizer.unk_token
+ tokenizer.pad_token = tokenizer.unk_token
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
+ dataset = IMDBDataset(tokenizer, data_args, training_args)
+
+ if not dataset.is_regression:
+ if "llama" in model_args.model_name_or_path:
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
+ config = AutoConfig.from_pretrained(
+ model_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ elif "opt" in model_args.model_name_or_path:
+ model_path = f'facebook/{model_args.model_name_or_path}'
+ config = AutoConfig.from_pretrained(
+ model_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ config.mask_token = tokenizer.unk_token
+ config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
+ config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+
+ config.trigger = training_args.trigger
+ config.clean_labels = training_args.clean_labels
+ config.target_labels = training_args.target_labels
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
+
+ # Initialize our Trainer
+ trainer = BaseTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
+ compute_metrics=dataset.compute_metrics,
+ tokenizer=tokenizer,
+ data_collator=dataset.data_collator,
+ )
+
+ return trainer, None
\ No newline at end of file
diff --git a/soft_prompt/tasks/ner/dataset.py b/soft_prompt/tasks/ner/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f92e95baa147c1a0f7820f7a7555d79e2166a5c5
--- /dev/null
+++ b/soft_prompt/tasks/ner/dataset.py
@@ -0,0 +1,126 @@
+import torch
+from torch.utils import data
+from torch.utils.data import Dataset
+from datasets.arrow_dataset import Dataset as HFDataset
+from datasets.load import load_dataset, load_metric
+from transformers import AutoTokenizer, DataCollatorForTokenClassification, AutoConfig
+import numpy as np
+
+
+class NERDataset():
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
+ super().__init__()
+ raw_datasets = load_dataset(f'tasks/ner/datasets/{data_args.dataset_name}.py')
+ self.tokenizer = tokenizer
+
+ if training_args.do_train:
+ column_names = raw_datasets["train"].column_names
+ features = raw_datasets["train"].features
+ else:
+ column_names = raw_datasets["validation"].column_names
+ features = raw_datasets["validation"].features
+
+ self.label_column_name = f"{data_args.task_name}_tags"
+ self.label_list = features[self.label_column_name].feature.names
+ self.label_to_id = {l: i for i, l in enumerate(self.label_list)}
+ self.num_labels = len(self.label_list)
+
+ if training_args.do_train:
+ train_dataset = raw_datasets['train']
+ if data_args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ self.train_dataset = train_dataset.map(
+ self.tokenize_and_align_labels,
+ batched=True,
+ load_from_cache_file=True,
+ desc="Running tokenizer on train dataset",
+ )
+ if training_args.do_eval:
+ eval_dataset = raw_datasets['validation']
+ if data_args.max_eval_samples is not None:
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ self.eval_dataset = eval_dataset.map(
+ self.tokenize_and_align_labels,
+ batched=True,
+ load_from_cache_file=True,
+ desc="Running tokenizer on validation dataset",
+ )
+ if training_args.do_predict:
+ predict_dataset = raw_datasets['test']
+ if data_args.max_predict_samples is not None:
+ predict_dataset = predict_dataset.select(range(data_args.max_predict_samples))
+ self.predict_dataset = predict_dataset.map(
+ self.tokenize_and_align_labels,
+ batched=True,
+ load_from_cache_file=True,
+ desc="Running tokenizer on test dataset",
+ )
+
+ self.data_collator = DataCollatorForTokenClassification(self.tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
+
+ self.metric = load_metric('seqeval')
+
+
+ def compute_metrics(self, p):
+ predictions, labels = p
+ predictions = np.argmax(predictions, axis=2)
+
+ # Remove ignored index (special tokens)
+ true_predictions = [
+ [self.label_list[p] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+ true_labels = [
+ [self.label_list[l] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+
+ results = self.metric.compute(predictions=true_predictions, references=true_labels)
+ return {
+ "precision": results["overall_precision"],
+ "recall": results["overall_recall"],
+ "f1": results["overall_f1"],
+ "accuracy": results["overall_accuracy"],
+ }
+
+ def tokenize_and_align_labels(self, examples):
+ tokenized_inputs = self.tokenizer(
+ examples['tokens'],
+ padding=False,
+ truncation=True,
+ # We use this argument because the texts in our dataset are lists of words (with a label for each word).
+ is_split_into_words=True,
+ )
+
+ labels = []
+ for i, label in enumerate(examples[self.label_column_name]):
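+ # Rebuild the word-to-token alignment manually: each word is re-encoded on its own
+ # and its index repeated once per sub-token, with None for the special tokens.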
+ word_ids = [None]
+ for j, word in enumerate(examples['tokens'][i]):
+ token = self.tokenizer.encode(word, add_special_tokens=False)
+ # print(token)
+ word_ids += [j] * len(token)
+ word_ids += [None]
+
+ # word_ids = tokenized_inputs.word_ids(batch_index=i)
+ previous_word_idx = None
+ label_ids = []
+ for word_idx in word_ids:
+ # Special tokens have a word id that is None. We set the label to -100 so they are automatically
+ # ignored in the loss function.
+ if word_idx is None:
+ label_ids.append(-100)
+ # We set the label for the first token of each word.
+ elif word_idx != previous_word_idx:
+ label_ids.append(label[word_idx])
+ # label_ids.append(self.label_to_id[label[word_idx]])
+ # For the other tokens in a word, we set the label to either the current label or -100, depending on
+ # the label_all_tokens flag.
+ else:
+ label_ids.append(-100)
+ previous_word_idx = word_idx
+
+ labels.append(label_ids)
+ tokenized_inputs["labels"] = labels
+ return tokenized_inputs
+
+
\ No newline at end of file
diff --git a/soft_prompt/tasks/ner/datasets/conll2003.py b/soft_prompt/tasks/ner/datasets/conll2003.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5448f8d50273bebeec8bc2ac22b1584f13d1702
--- /dev/null
+++ b/soft_prompt/tasks/ner/datasets/conll2003.py
@@ -0,0 +1,238 @@
+# coding=utf-8
+# Copyright 2020 HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition"""
+
+import datasets
+
+
+logger = datasets.logging.get_logger(__name__)
+
+
+_CITATION = """\
+@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,
+ title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition",
+ author = "Tjong Kim Sang, Erik F. and
+ De Meulder, Fien",
+ booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003",
+ year = "2003",
+ url = "https://www.aclweb.org/anthology/W03-0419",
+ pages = "142--147",
+}
+"""
+
+_DESCRIPTION = """\
+The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on
+four types of named entities: persons, locations, organizations and names of miscellaneous entities that do
+not belong to the previous three groups.
+The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on
+a separate line and there is an empty line after each sentence. The first item on each line is a word, the second
+a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags
+and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only
+if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag
+B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2
+tagging scheme, whereas the original dataset uses IOB1.
+For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419
+"""
+
+_URL = "../../../data/CoNLL03/"
+_TRAINING_FILE = "train.txt"
+_DEV_FILE = "valid.txt"
+_TEST_FILE = "test.txt"
+
+
+class Conll2003Config(datasets.BuilderConfig):
+ """BuilderConfig for Conll2003"""
+
+ def __init__(self, **kwargs):
+ """BuilderConfig forConll2003.
+ Args:
+ **kwargs: keyword arguments forwarded to super.
+ """
+ super(Conll2003Config, self).__init__(**kwargs)
+
+
+class Conll2003(datasets.GeneratorBasedBuilder):
+ """Conll2003 dataset."""
+
+ BUILDER_CONFIGS = [
+ Conll2003Config(name="conll2003", version=datasets.Version("1.0.0"), description="Conll2003 dataset"),
+ ]
+
+ def _info(self):
+ return datasets.DatasetInfo(
+ description=_DESCRIPTION,
+ features=datasets.Features(
+ {
+ "id": datasets.Value("string"),
+ "tokens": datasets.Sequence(datasets.Value("string")),
+ "pos_tags": datasets.Sequence(
+ datasets.features.ClassLabel(
+ names=[
+ '"',
+ "''",
+ "#",
+ "$",
+ "(",
+ ")",
+ ",",
+ ".",
+ ":",
+ "``",
+ "CC",
+ "CD",
+ "DT",
+ "EX",
+ "FW",
+ "IN",
+ "JJ",
+ "JJR",
+ "JJS",
+ "LS",
+ "MD",
+ "NN",
+ "NNP",
+ "NNPS",
+ "NNS",
+ "NN|SYM",
+ "PDT",
+ "POS",
+ "PRP",
+ "PRP$",
+ "RB",
+ "RBR",
+ "RBS",
+ "RP",
+ "SYM",
+ "TO",
+ "UH",
+ "VB",
+ "VBD",
+ "VBG",
+ "VBN",
+ "VBP",
+ "VBZ",
+ "WDT",
+ "WP",
+ "WP$",
+ "WRB",
+ ]
+ )
+ ),
+ "chunk_tags": datasets.Sequence(
+ datasets.features.ClassLabel(
+ names=[
+ "O",
+ "B-ADJP",
+ "I-ADJP",
+ "B-ADVP",
+ "I-ADVP",
+ "B-CONJP",
+ "I-CONJP",
+ "B-INTJ",
+ "I-INTJ",
+ "B-LST",
+ "I-LST",
+ "B-NP",
+ "I-NP",
+ "B-PP",
+ "I-PP",
+ "B-PRT",
+ "I-PRT",
+ "B-SBAR",
+ "I-SBAR",
+ "B-UCP",
+ "I-UCP",
+ "B-VP",
+ "I-VP",
+ ]
+ )
+ ),
+ "ner_tags": datasets.Sequence(
+ datasets.features.ClassLabel(
+ names=[
+ "O",
+ "B-PER",
+ "I-PER",
+ "B-ORG",
+ "I-ORG",
+ "B-LOC",
+ "I-LOC",
+ "B-MISC",
+ "I-MISC",
+ ]
+ )
+ ),
+ }
+ ),
+ supervised_keys=None,
+ homepage="https://www.aclweb.org/anthology/W03-0419/",
+ citation=_CITATION,
+ )
+
+ def _split_generators(self, dl_manager):
+ """Returns SplitGenerators."""
+ urls_to_download = {
+ "train": f"{_URL}{_TRAINING_FILE}",
+ "dev": f"{_URL}{_DEV_FILE}",
+ "test": f"{_URL}{_TEST_FILE}",
+ }
+ downloaded_files = dl_manager.download_and_extract(urls_to_download)
+
+ return [
+ datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
+ datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}),
+ datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": downloaded_files["test"]}),
+ ]
+
+ def _generate_examples(self, filepath):
+ logger.info("⏳ Generating examples from = %s", filepath)
+ with open(filepath, encoding="utf-8") as f:
+ guid = 0
+ tokens = []
+ pos_tags = []
+ chunk_tags = []
+ ner_tags = []
+ for line in f:
+ if line.startswith("-DOCSTART-") or line == "" or line == "\n":
+ if tokens:
+ yield guid, {
+ "id": str(guid),
+ "tokens": tokens,
+ "pos_tags": pos_tags,
+ "chunk_tags": chunk_tags,
+ "ner_tags": ner_tags,
+ }
+ guid += 1
+ tokens = []
+ pos_tags = []
+ chunk_tags = []
+ ner_tags = []
+ else:
+ # conll2003 tokens are space separated
+ splits = line.split(" ")
+ tokens.append(splits[0])
+ pos_tags.append(splits[1])
+ chunk_tags.append(splits[2])
+ ner_tags.append(splits[3].rstrip())
+ # last example
+ yield guid, {
+ "id": str(guid),
+ "tokens": tokens,
+ "pos_tags": pos_tags,
+ "chunk_tags": chunk_tags,
+ "ner_tags": ner_tags,
+ }
\ No newline at end of file
diff --git a/soft_prompt/tasks/ner/datasets/conll2004.py b/soft_prompt/tasks/ner/datasets/conll2004.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eb14ffeecb7a1a49f941263316da5af2097ee9c
--- /dev/null
+++ b/soft_prompt/tasks/ner/datasets/conll2004.py
@@ -0,0 +1,130 @@
+# coding=utf-8
+# Copyright 2020 HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition"""
+
+import datasets
+
+
+logger = datasets.logging.get_logger(__name__)
+
+
+_CITATION = """\
+@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,
+ title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition",
+ author = "Tjong Kim Sang, Erik F. and
+ De Meulder, Fien",
+ booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003",
+ year = "2003",
+ url = "https://www.aclweb.org/anthology/W03-0419",
+ pages = "142--147",
+}
+"""
+
+_DESCRIPTION = """\
+The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on
+four types of named entities: persons, locations, organizations and names of miscellaneous entities that do
+not belong to the previous three groups.
+The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on
+a separate line and there is an empty line after each sentence. The first item on each line is a word, the second
+a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags
+and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only
+if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag
+B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2
+tagging scheme, whereas the original dataset uses IOB1.
+For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419
+"""
+
+_URL = "../../../data/CoNLL04/"
+_TRAINING_FILE = "train.txt"
+_DEV_FILE = "dev.txt"
+_TEST_FILE = "test.txt"
+
+
+class CoNLL2004Config(datasets.BuilderConfig):
+ """BuilderConfig for CoNLL2004"""
+
+ def __init__(self, **kwargs):
+ """BuilderConfig for CoNLL2004 5.0.
+ Args:
+ **kwargs: keyword arguments forwarded to super.
+ """
+ super(CoNLL2004Config, self).__init__(**kwargs)
+
+
+class CoNLL2004(datasets.GeneratorBasedBuilder):
+ """CoNLL2004 dataset."""
+
+ BUILDER_CONFIGS = [
+ CoNLL2004Config(name="CoNLL2004", version=datasets.Version("1.0.0"), description="CoNLL2004 dataset"),
+ ]
+
+ def _info(self):
+ return datasets.DatasetInfo(
+ description=_DESCRIPTION,
+ features=datasets.Features(
+ {
+ "id": datasets.Value("string"),
+ "tokens": datasets.Sequence(datasets.Value("string")),
+ "ner_tags": datasets.Sequence(
+ datasets.features.ClassLabel(
+ names= ['O', 'B-Loc', 'B-Peop', 'B-Org', 'B-Other', 'I-Loc', 'I-Peop', 'I-Org', 'I-Other']
+ )
+ ),
+ }
+ ),
+ supervised_keys=None,
+ homepage="https://catalog.ldc.upenn.edu/LDC2013T19",
+ citation=_CITATION,
+ )
+
+ def _split_generators(self, dl_manager):
+ """Returns SplitGenerators."""
+ urls_to_download = {
+ "train": f"{_URL}{_TRAINING_FILE}",
+ "dev": f"{_URL}{_DEV_FILE}",
+ "test": f"{_URL}{_TEST_FILE}",
+ }
+ downloaded_files = dl_manager.download_and_extract(urls_to_download)
+
+ return [
+ datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
+ datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}),
+ datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": downloaded_files["test"]}),
+ ]
+
+ def _generate_examples(self, filepath):
+ logger.info("⏳ Generating examples from = %s", filepath)
+ with open(filepath, encoding="utf-8") as f:
+ guid = 0
+ tokens = []
+ ner_tags = []
+ for line in f:
+ if line.startswith("-DOCSTART-") or line == "" or line == "\n":
+ if tokens:
+ yield guid, {
+ "id": str(guid),
+ "tokens": tokens,
+ "ner_tags": ner_tags,
+ }
+ guid += 1
+ tokens = []
+ ner_tags = []
+ else:
+ # CoNLL04 tokens are space separated; the token is in the first column and the NER tag in the last
+ splits = line.split(" ")
+ tokens.append(splits[0])
+ ner_tags.append(splits[-1].rstrip())
\ No newline at end of file
diff --git a/soft_prompt/tasks/ner/datasets/ontonotes.py b/soft_prompt/tasks/ner/datasets/ontonotes.py
new file mode 100644
index 0000000000000000000000000000000000000000..e814ed8bd23501d5aa2fd19ebe7d4dbf5182e0e6
--- /dev/null
+++ b/soft_prompt/tasks/ner/datasets/ontonotes.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+# Copyright 2020 HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition"""
+
+import datasets
+
+
+logger = datasets.logging.get_logger(__name__)
+
+
+_CITATION = """\
+@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,
+ title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition",
+ author = "Tjong Kim Sang, Erik F. and
+ De Meulder, Fien",
+ booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003",
+ year = "2003",
+ url = "https://www.aclweb.org/anthology/W03-0419",
+ pages = "142--147",
+}
+"""
+
+_DESCRIPTION = """\
+The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on
+four types of named entities: persons, locations, organizations and names of miscellaneous entities that do
+not belong to the previous three groups.
+The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on
+a separate line and there is an empty line after each sentence. The first item on each line is a word, the second
+a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags
+and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only
+if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag
+B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2
+tagging scheme, whereas the original dataset uses IOB1.
+For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419
+"""
+
+_URL = "../../../data/ontoNotes/"
+_TRAINING_FILE = "train.sd.conllx"
+_DEV_FILE = "dev.sd.conllx"
+_TEST_FILE = "test.sd.conllx"
+
+
+class OntoNotesConfig(datasets.BuilderConfig):
+ """BuilderConfig for OntoNotes 5.0"""
+
+ def __init__(self, **kwargs):
+ """BuilderConfig forOntoNotes 5.0.
+ Args:
+ **kwargs: keyword arguments forwarded to super.
+ """
+ super(OntoNotesConfig, self).__init__(**kwargs)
+
+
+class OntoNotes(datasets.GeneratorBasedBuilder):
+ """OntoNotes 5.0 dataset."""
+
+ BUILDER_CONFIGS = [
+ OntoNotesConfig(name="ontoNotes", version=datasets.Version("5.0.0"), description="ontoNotes 5.0 dataset"),
+ ]
+
+ def _info(self):
+ return datasets.DatasetInfo(
+ description=_DESCRIPTION,
+ features=datasets.Features(
+ {
+ "id": datasets.Value("string"),
+ "tokens": datasets.Sequence(datasets.Value("string")),
+ "ner_tags": datasets.Sequence(
+ datasets.features.ClassLabel(
+ names=['B-CARDINAL', 'B-DATE', 'B-EVENT', 'B-FAC', 'B-GPE', 'B-LANGUAGE', 'B-LAW', 'B-LOC', 'B-MONEY', 'B-NORP', 'B-ORDINAL', 'B-ORG', 'B-PERCENT', 'B-PERSON', 'B-PRODUCT', 'B-QUANTITY', 'B-TIME', 'B-WORK_OF_ART', 'I-CARDINAL', 'I-DATE', 'I-EVENT', 'I-FAC', 'I-GPE', 'I-LANGUAGE', 'I-LAW', 'I-LOC', 'I-MONEY', 'I-NORP', 'I-ORDINAL', 'I-ORG', 'I-PERCENT', 'I-PERSON', 'I-PRODUCT', 'I-QUANTITY', 'I-TIME', 'I-WORK_OF_ART', 'O']
+ )
+ ),
+ }
+ ),
+ supervised_keys=None,
+ homepage="https://catalog.ldc.upenn.edu/LDC2013T19",
+ citation=_CITATION,
+ )
+
+ def _split_generators(self, dl_manager):
+ """Returns SplitGenerators."""
+ urls_to_download = {
+ "train": f"{_URL}{_TRAINING_FILE}",
+ "dev": f"{_URL}{_DEV_FILE}",
+ "test": f"{_URL}{_TEST_FILE}",
+ }
+ downloaded_files = dl_manager.download_and_extract(urls_to_download)
+
+ return [
+ datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}),
+ datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}),
+ datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": downloaded_files["test"]}),
+ ]
+
+ def _generate_examples(self, filepath):
+ logger.info("⏳ Generating examples from = %s", filepath)
+ with open(filepath, encoding="utf-8") as f:
+ guid = 0
+ tokens = []
+ ner_tags = []
+ for line in f:
+ if line.startswith("-DOCSTART-") or line == "" or line == "\n":
+ if tokens:
+ yield guid, {
+ "id": str(guid),
+ "tokens": tokens,
+ "ner_tags": ner_tags,
+ }
+ guid += 1
+ tokens = []
+ ner_tags = []
+ else:
+ # OntoNotes 5.0 columns are tab separated; the token is in column 1 and the NER tag in the last column
+ splits = line.split("\t")
+ tokens.append(splits[1])
+ ner_tags.append(splits[-1].rstrip())
+ # last example
+ # yield guid, {
+ # "id": str(guid),
+ # "tokens": tokens,
+ # "ner_tags": ner_tags,
+ # }
\ No newline at end of file
diff --git a/soft_prompt/tasks/ner/get_trainer.py b/soft_prompt/tasks/ner/get_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e72e06c40220df2d426035a9912a7a621a608ded
--- /dev/null
+++ b/soft_prompt/tasks/ner/get_trainer.py
@@ -0,0 +1,74 @@
+import logging
+import os
+import random
+import sys
+
+from transformers import (
+ AutoConfig,
+ AutoTokenizer,
+)
+
+from tasks.ner.dataset import NERDataset
+from training.trainer_exp import ExponentialTrainer
+from model.utils import get_model, TaskType
+from tasks.utils import ADD_PREFIX_SPACE, USE_FAST
+
+logger = logging.getLogger(__name__)
+
+
+def get_trainer(args):
+ model_args, data_args, training_args, qa_args = args
+
+ log_level = training_args.get_process_log_level()
+ logger.setLevel(log_level)
+
+ model_type = AutoConfig.from_pretrained(model_args.model_name_or_path).model_type
+
+ add_prefix_space = ADD_PREFIX_SPACE[model_type]
+
+ use_fast = USE_FAST[model_type]
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=use_fast,
+ revision=model_args.model_revision,
+ add_prefix_space=add_prefix_space,
+ )
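+ # Note: add_prefix_space / use_fast are looked up per model type; byte-level BPE
+ # tokenizers (e.g. RoBERTa/GPT-2) need add_prefix_space=True when inputs are
+ # pre-split into words, which is presumably what ADD_PREFIX_SPACE encodes.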
+
+ dataset = NERDataset(tokenizer, data_args, training_args)
+
+ if training_args.do_train:
+ for index in random.sample(range(len(dataset.train_dataset)), 3):
+ logger.info(f"Sample {index} of the training set: {dataset.train_dataset[index]}.")
+
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label_to_id,
+ id2label={i: l for l, i in dataset.label_to_id.items()},
+ revision=model_args.model_revision,
+ )
+
+ model = get_model(model_args, TaskType.TOKEN_CLASSIFICATION, config, fix_bert=True)
+
+ trainer = ExponentialTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
+ predict_dataset=dataset.predict_dataset if training_args.do_predict else None,
+ tokenizer=tokenizer,
+ data_collator=dataset.data_collator,
+ compute_metrics=dataset.compute_metrics,
+ test_key="f1"
+ )
+ return trainer, dataset.predict_dataset
diff --git a/soft_prompt/tasks/qa/dataset.py b/soft_prompt/tasks/qa/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b2c1aa1b95ccebe10cec677d83f8b42d27b282f
--- /dev/null
+++ b/soft_prompt/tasks/qa/dataset.py
@@ -0,0 +1,182 @@
+import torch
+from torch.utils.data.sampler import RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader
+from datasets.arrow_dataset import Dataset as HFDataset
+from datasets.load import load_metric, load_dataset
+from transformers import AutoTokenizer, DataCollatorForTokenClassification, BertConfig
+from transformers import default_data_collator, EvalPrediction
+import numpy as np
+import logging
+
+from tasks.qa.utils_qa import postprocess_qa_predictions
+
+class SQuAD:
+
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args, qa_args) -> None:
+ self.data_args = data_args
+ self.training_args = training_args
+ self.qa_args = qa_args
+ self.version_2 = data_args.dataset_name == "squad_v2"
+
+ raw_datasets = load_dataset(data_args.dataset_name)
+ column_names = raw_datasets['train'].column_names
+ self.question_column_name = "question"
+ self.context_column_name = "context"
+ self.answer_column_name = "answers"
+
+ self.tokenizer = tokenizer
+
+ self.pad_on_right = tokenizer.padding_side == "right" # True
+ self.max_seq_len = 384  # hard-coded; overrides data_args.max_seq_length
+
+ if training_args.do_train:
+ self.train_dataset = raw_datasets['train']
+ self.train_dataset = self.train_dataset.map(
+ self.prepare_train_dataset,
+ batched=True,
+ remove_columns=column_names,
+ load_from_cache_file=True,
+ desc="Running tokenizer on train dataset",
+ )
+ if data_args.max_train_samples is not None:
+ self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))
+
+ if training_args.do_eval:
+ self.eval_examples = raw_datasets['validation']
+ if data_args.max_eval_samples is not None:
+ self.eval_examples = self.eval_examples.select(range(data_args.max_eval_samples))
+ self.eval_dataset = self.eval_examples.map(
+ self.prepare_eval_dataset,
+ batched=True,
+ remove_columns=column_names,
+ load_from_cache_file=True,
+ desc="Running tokenizer on validation dataset",
+ )
+ if data_args.max_eval_samples is not None:
+ self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))
+
+ self.predict_dataset = None
+
+ self.data_collator = default_data_collator
+
+ self.metric = load_metric(data_args.dataset_name)
+
+ def prepare_train_dataset(self, examples):
+ examples['question'] = [q.lstrip() for q in examples['question']]
+
+ tokenized = self.tokenizer(
+ examples['question' if self.pad_on_right else 'context'],
+ examples['context' if self.pad_on_right else 'question'],
+ truncation='only_second' if self.pad_on_right else 'only_first',
+ max_length=self.max_seq_len,
+ stride=128,
+ return_overflowing_tokens=True,
+ return_offsets_mapping=True,
+ padding="max_length",
+ )
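+ # Long contexts are split into overlapping features of max_seq_len tokens with a
+ # 128-token stride; "overflow_to_sample_mapping" maps each feature back to its
+ # source example, and "offset_mapping" gives character spans for every token.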
+
+ sample_mapping = tokenized.pop("overflow_to_sample_mapping")
+ offset_mapping = tokenized.pop("offset_mapping")
+ tokenized["start_positions"] = []
+ tokenized["end_positions"] = []
+
+ for i, offsets in enumerate(offset_mapping):
+ input_ids = tokenized['input_ids'][i]
+ cls_index = input_ids.index(self.tokenizer.cls_token_id)
+
+ sequence_ids = tokenized.sequence_ids(i)
+ sample_index = sample_mapping[i]
+ answers = examples['answers'][sample_index]
+
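+ # Unanswerable or out-of-span answers are labelled with the [CLS] position,
+ # following the usual SQuAD-style convention.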
+ if len(answers['answer_start']) == 0:
+ tokenized["start_positions"].append(cls_index)
+ tokenized["end_positions"].append(cls_index)
+ else:
+ start_char = answers["answer_start"][0]
+ end_char = start_char + len(answers["text"][0])
+
+ token_start_index = 0
+ while sequence_ids[token_start_index] != (1 if self.pad_on_right else 0):
+ token_start_index += 1
+
+ token_end_index = len(input_ids) - 1
+ while sequence_ids[token_end_index] != (1 if self.pad_on_right else 0):
+ token_end_index -= 1
+
+ # Detect if the answer is out of the span
+ # (in which case this feature is labeled with the CLS index).
+ if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
+ tokenized["start_positions"].append(cls_index)
+ tokenized["end_positions"].append(cls_index)
+ else:
+ # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
+ # Note: we could go after the last offset if the answer is the last word (edge case).
+ while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
+ token_start_index += 1
+ tokenized["start_positions"].append(token_start_index - 1)
+ while offsets[token_end_index][1] >= end_char:
+ token_end_index -= 1
+ tokenized["end_positions"].append(token_end_index + 1)
+
+ return tokenized
+
+ def prepare_eval_dataset(self, examples):
+ # if self.version_2:
+ examples['question'] = [q.lstrip() for q in examples['question']]
+
+ tokenized = self.tokenizer(
+ examples['question' if self.pad_on_right else 'context'],
+ examples['context' if self.pad_on_right else 'question'],
+ truncation='only_second' if self.pad_on_right else 'only_first',
+ max_length=self.max_seq_len,
+ stride=128,
+ return_overflowing_tokens=True,
+ return_offsets_mapping=True,
+ padding="max_length",
+ )
+
+ sample_mapping = tokenized.pop("overflow_to_sample_mapping")
+ tokenized["example_id"] = []
+
+ for i in range(len(tokenized["input_ids"])):
+ # Grab the sequence corresponding to that example (to know what is the context and what is the question).
+ sequence_ids = tokenized.sequence_ids(i)
+ context_index = 1 if self.pad_on_right else 0
+
+ # One example can give several spans, this is the index of the example containing this span of text.
+ sample_index = sample_mapping[i]
+ tokenized["example_id"].append(examples["id"][sample_index])
+
+ # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
+ # position is part of the context or not.
+ tokenized["offset_mapping"][i] = [
+ (o if sequence_ids[k] == context_index else None)
+ for k, o in enumerate(tokenized["offset_mapping"][i])
+ ]
+ return tokenized
+
+ def compute_metrics(self, p: EvalPrediction):
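+ # With the QuestionAnsweringTrainer used here, `p` already holds the formatted
+ # text predictions and references produced by post_processing_function below,
+ # not raw logits.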
+ return self.metric.compute(predictions=p.predictions, references=p.label_ids)
+
+ def post_processing_function(self, examples, features, predictions, stage='eval'):
+ predictions = postprocess_qa_predictions(
+ examples=examples,
+ features=features,
+ predictions=predictions,
+ version_2_with_negative=self.version_2,
+ n_best_size=self.qa_args.n_best_size,
+ max_answer_length=self.qa_args.max_answer_length,
+ null_score_diff_threshold=self.qa_args.null_score_diff_threshold,
+ output_dir=self.training_args.output_dir,
+ prefix=stage,
+ log_level=logging.INFO
+ )
+ if self.version_2: # squad_v2
+ formatted_predictions = [
+ {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
+ ]
+ else:
+ formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()]
+
+ references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples]
+ return EvalPrediction(predictions=formatted_predictions, label_ids=references)
diff --git a/soft_prompt/tasks/qa/get_trainer.py b/soft_prompt/tasks/qa/get_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2d7a225557b4e19ebdf478615739011d5deca43
--- /dev/null
+++ b/soft_prompt/tasks/qa/get_trainer.py
@@ -0,0 +1,50 @@
+import logging
+import os
+import random
+import sys
+
+from transformers import (
+ AutoConfig,
+ AutoTokenizer,
+)
+
+from tasks.qa.dataset import SQuAD
+from training.trainer_qa import QuestionAnsweringTrainer
+from model.utils import get_model, TaskType
+
+logger = logging.getLogger(__name__)
+
+def get_trainer(args):
+ model_args, data_args, training_args, qa_args = args
+
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=2,
+ revision=model_args.model_revision,
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ revision=model_args.model_revision,
+ use_fast=True,
+ )
+
+ model = get_model(model_args, TaskType.QUESTION_ANSWERING, config, fix_bert=True)
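+ # fix_bert=True presumably freezes the backbone encoder so that only the prompt
+ # parameters and the QA head are trained.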
+
+ dataset = SQuAD(tokenizer, data_args, training_args, qa_args)
+
+ trainer = QuestionAnsweringTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
+ eval_examples=dataset.eval_examples if training_args.do_eval else None,
+ tokenizer=tokenizer,
+ data_collator=dataset.data_collator,
+ post_process_function=dataset.post_processing_function,
+ compute_metrics=dataset.compute_metrics,
+ )
+
+ return trainer, dataset.predict_dataset
+
+
diff --git a/soft_prompt/tasks/qa/utils_qa.py b/soft_prompt/tasks/qa/utils_qa.py
new file mode 100644
index 0000000000000000000000000000000000000000..159cbf2bab0a2bd8177a60184b51ab5b3c0a4e09
--- /dev/null
+++ b/soft_prompt/tasks/qa/utils_qa.py
@@ -0,0 +1,427 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Post-processing utilities for question answering.
+"""
+import collections
+import json
+import logging
+import os
+from typing import Optional, Tuple
+
+import numpy as np
+from tqdm.auto import tqdm
+
+
+logger = logging.getLogger(__name__)
+
+
+def postprocess_qa_predictions(
+ examples,
+ features,
+ predictions: Tuple[np.ndarray, np.ndarray],
+ version_2_with_negative: bool = False,
+ n_best_size: int = 20,
+ max_answer_length: int = 30,
+ null_score_diff_threshold: float = 0.0,
+ output_dir: Optional[str] = None,
+ prefix: Optional[str] = None,
+ log_level: Optional[int] = logging.WARNING,
+):
+ """
+ Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
+ original contexts. This is the base postprocessing function for models that only return start and end logits.
+
+ Args:
+ examples: The non-preprocessed dataset (see the main script for more information).
+ features: The processed dataset (see the main script for more information).
+ predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
+ The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
+ first dimension must match the number of elements of :obj:`features`.
+ version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the underlying dataset contains examples with no answers.
+ n_best_size (:obj:`int`, `optional`, defaults to 20):
+ The total number of n-best predictions to generate when looking for an answer.
+ max_answer_length (:obj:`int`, `optional`, defaults to 30):
+ The maximum length of an answer that can be generated. This is needed because the start and end predictions
+ are not conditioned on one another.
+ null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
+ The threshold used to select the null answer: if the best answer has a score that is less than the score of
+ the null answer minus this threshold, the null answer is selected for this example (note that the score of
+ the null answer for an example giving several features is the minimum of the scores for the null answer on
+ each feature: all features must be aligned on the fact they `want` to predict a null answer).
+
+ Only useful when :obj:`version_2_with_negative` is :obj:`True`.
+ output_dir (:obj:`str`, `optional`):
+ If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
+ :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
+ answers, are saved in `output_dir`.
+ prefix (:obj:`str`, `optional`):
+ If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
+ log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
+ ``logging`` log level (e.g., ``logging.WARNING``)
+ """
+ assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)."
+ all_start_logits, all_end_logits = predictions
+
+ assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
+
+ # Build a map example to its corresponding features.
+ example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
+ features_per_example = collections.defaultdict(list)
+ for i, feature in enumerate(features):
+ features_per_example[example_id_to_index[feature["example_id"]]].append(i)
+
+ # The dictionaries we have to fill.
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ if version_2_with_negative:
+ scores_diff_json = collections.OrderedDict()
+
+ # Logging.
+ logger.setLevel(log_level)
+ logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
+
+ # Let's loop over all the examples!
+ for example_index, example in enumerate(tqdm(examples)):
+ # Those are the indices of the features associated to the current example.
+ feature_indices = features_per_example[example_index]
+
+ min_null_prediction = None
+ prelim_predictions = []
+ # Looping through all the features associated to the current example.
+ for feature_index in feature_indices:
+ # We grab the predictions of the model for this feature.
+ start_logits = all_start_logits[feature_index]
+ end_logits = all_end_logits[feature_index]
+ # This is what will allow us to map some of the positions in our logits to spans of text in the original
+ # context.
+ offset_mapping = features[feature_index]["offset_mapping"]
+ # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
+ # available in the current feature.
+ token_is_max_context = features[feature_index].get("token_is_max_context", None)
+
+ # Update minimum null prediction.
+ feature_null_score = start_logits[0] + end_logits[0]
+ if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
+ min_null_prediction = {
+ "offsets": (0, 0),
+ "score": feature_null_score,
+ "start_logit": start_logits[0],
+ "end_logit": end_logits[0],
+ }
+
+ # Go through all possibilities for the `n_best_size` greatest start and end logits.
+ start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
+ end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
+ for start_index in start_indexes:
+ for end_index in end_indexes:
+ # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
+ # to part of the input_ids that are not in the context.
+ if (
+ start_index >= len(offset_mapping)
+ or end_index >= len(offset_mapping)
+ or offset_mapping[start_index] is None
+ or offset_mapping[end_index] is None
+ ):
+ continue
+ # Don't consider answers with a length that is either < 0 or > max_answer_length.
+ if end_index < start_index or end_index - start_index + 1 > max_answer_length:
+ continue
+ # Don't consider answers that don't have the maximum context available (if such information is
+ # provided).
+ if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
+ continue
+ prelim_predictions.append(
+ {
+ "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
+ "score": start_logits[start_index] + end_logits[end_index],
+ "start_logit": start_logits[start_index],
+ "end_logit": end_logits[end_index],
+ }
+ )
+
+ if version_2_with_negative:
+ # Add the minimum null prediction
+ prelim_predictions.append(min_null_prediction)
+ null_score = min_null_prediction["score"]
+
+ # Only keep the best `n_best_size` predictions.
+ predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
+
+ # Add back the minimum null prediction if it was removed because of its low score.
+ if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
+ predictions.append(min_null_prediction)
+
+ # Use the offsets to gather the answer text in the original context.
+ context = example["context"]
+ for pred in predictions:
+ offsets = pred.pop("offsets")
+ pred["text"] = context[offsets[0] : offsets[1]]
+
+ # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
+ # failure.
+ if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
+ predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})
+
+ # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
+ # the LogSumExp trick).
+ scores = np.array([pred.pop("score") for pred in predictions])
+ exp_scores = np.exp(scores - np.max(scores))
+ probs = exp_scores / exp_scores.sum()
+
+ # Include the probabilities in our predictions.
+ for prob, pred in zip(probs, predictions):
+ pred["probability"] = prob
+
+ # Pick the best prediction. If the null answer is not possible, this is easy.
+ if not version_2_with_negative:
+ all_predictions[example["id"]] = predictions[0]["text"]
+ else:
+ # Otherwise we first need to find the best non-empty prediction.
+ i = 0
+ while predictions[i]["text"] == "":
+ i += 1
+ best_non_null_pred = predictions[i]
+
+ # Then we compare to the null prediction using the threshold.
+ score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
+ scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable.
+ if score_diff > null_score_diff_threshold:
+ all_predictions[example["id"]] = ""
+ else:
+ all_predictions[example["id"]] = best_non_null_pred["text"]
+
+ # Make `predictions` JSON-serializable by casting np.float back to float.
+ all_nbest_json[example["id"]] = [
+ {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
+ for pred in predictions
+ ]
+
+ # If we have an output_dir, let's save all those dicts.
+ if output_dir is not None:
+ assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
+
+ prediction_file = os.path.join(
+ output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
+ )
+ nbest_file = os.path.join(
+ output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
+ )
+ if version_2_with_negative:
+ null_odds_file = os.path.join(
+ output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
+ )
+
+ logger.info(f"Saving predictions to {prediction_file}.")
+ with open(prediction_file, "w") as writer:
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
+ logger.info(f"Saving nbest_preds to {nbest_file}.")
+ with open(nbest_file, "w") as writer:
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
+ if version_2_with_negative:
+ logger.info(f"Saving null_odds to {null_odds_file}.")
+ with open(null_odds_file, "w") as writer:
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
+ return all_predictions
+
+
+def postprocess_qa_predictions_with_beam_search(
+ examples,
+ features,
+ predictions: Tuple[np.ndarray, np.ndarray],
+ version_2_with_negative: bool = False,
+ n_best_size: int = 20,
+ max_answer_length: int = 30,
+ start_n_top: int = 5,
+ end_n_top: int = 5,
+ output_dir: Optional[str] = None,
+ prefix: Optional[str] = None,
+ log_level: Optional[int] = logging.WARNING,
+):
+ """
+ Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the
+ original contexts. This is the postprocessing function for models that return start and end logits, indices, as well as
+ cls token predictions.
+
+ Args:
+ examples: The non-preprocessed dataset (see the main script for more information).
+ features: The processed dataset (see the main script for more information).
+ predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
+ The predictions of the model: five arrays containing the start top log-probabilities and indices, the end top log-probabilities and indices, and the cls logits. Its
+ first dimension must match the number of elements of :obj:`features`.
+ version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
+ Whether or not the underlying dataset contains examples with no answers.
+ n_best_size (:obj:`int`, `optional`, defaults to 20):
+ The total number of n-best predictions to generate when looking for an answer.
+ max_answer_length (:obj:`int`, `optional`, defaults to 30):
+ The maximum length of an answer that can be generated. This is needed because the start and end predictions
+ are not conditioned on one another.
+ start_n_top (:obj:`int`, `optional`, defaults to 5):
+ The number of top start logits to keep when searching for the :obj:`n_best_size` predictions.
+ end_n_top (:obj:`int`, `optional`, defaults to 5):
+ The number of top end logits to keep when searching for the :obj:`n_best_size` predictions.
+ output_dir (:obj:`str`, `optional`):
+ If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
+ :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
+ answers, are saved in `output_dir`.
+ prefix (:obj:`str`, `optional`):
+ If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
+ log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
+ ``logging`` log level (e.g., ``logging.WARNING``)
+ """
+ assert len(predictions) == 5, "`predictions` should be a tuple with five elements."
+ start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions
+
+ assert len(predictions[0]) == len(
+ features
+ ), f"Got {len(predictions[0])} predicitions and {len(features)} features."
+
+ # Build a map example to its corresponding features.
+ example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
+ features_per_example = collections.defaultdict(list)
+ for i, feature in enumerate(features):
+ features_per_example[example_id_to_index[feature["example_id"]]].append(i)
+
+ # The dictionaries we have to fill.
+ all_predictions = collections.OrderedDict()
+ all_nbest_json = collections.OrderedDict()
+ scores_diff_json = collections.OrderedDict() if version_2_with_negative else None
+
+ # Logging.
+ logger.setLevel(log_level)
+ logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")
+
+ # Let's loop over all the examples!
+ for example_index, example in enumerate(tqdm(examples)):
+ # Those are the indices of the features associated to the current example.
+ feature_indices = features_per_example[example_index]
+
+ min_null_score = None
+ prelim_predictions = []
+
+ # Looping through all the features associated to the current example.
+ for feature_index in feature_indices:
+ # We grab the predictions of the model for this feature.
+ start_log_prob = start_top_log_probs[feature_index]
+ start_indexes = start_top_index[feature_index]
+ end_log_prob = end_top_log_probs[feature_index]
+ end_indexes = end_top_index[feature_index]
+ feature_null_score = cls_logits[feature_index]
+ # This is what will allow us to map some of the positions in our logits to spans of text in the original
+ # context.
+ offset_mapping = features[feature_index]["offset_mapping"]
+ # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
+ # available in the current feature.
+ token_is_max_context = features[feature_index].get("token_is_max_context", None)
+
+ # Update minimum null prediction
+ if min_null_score is None or feature_null_score < min_null_score:
+ min_null_score = feature_null_score
+
+ # Go through all possibilities for the `n_start_top`/`n_end_top` greatest start and end logits.
+ for i in range(start_n_top):
+ for j in range(end_n_top):
+ start_index = int(start_indexes[i])
+ j_index = i * end_n_top + j
+ end_index = int(end_indexes[j_index])
+ # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the
+ # p_mask but let's not take any risk)
+ if (
+ start_index >= len(offset_mapping)
+ or end_index >= len(offset_mapping)
+ or offset_mapping[start_index] is None
+ or offset_mapping[end_index] is None
+ ):
+ continue
+ # Don't consider answers with a length negative or > max_answer_length.
+ if end_index < start_index or end_index - start_index + 1 > max_answer_length:
+ continue
+ # Don't consider answers that don't have the maximum context available (if such information is
+ # provided).
+ if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
+ continue
+ prelim_predictions.append(
+ {
+ "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
+ "score": start_log_prob[i] + end_log_prob[j_index],
+ "start_log_prob": start_log_prob[i],
+ "end_log_prob": end_log_prob[j_index],
+ }
+ )
+
+ # Only keep the best `n_best_size` predictions.
+ predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]
+
+ # Use the offsets to gather the answer text in the original context.
+ context = example["context"]
+ for pred in predictions:
+ offsets = pred.pop("offsets")
+ pred["text"] = context[offsets[0] : offsets[1]]
+
+ # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
+ # failure.
+ if len(predictions) == 0:
+ predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6})
+
+ # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
+ # the LogSumExp trick).
+ scores = np.array([pred.pop("score") for pred in predictions])
+ exp_scores = np.exp(scores - np.max(scores))
+ probs = exp_scores / exp_scores.sum()
+
+ # Include the probabilities in our predictions.
+ for prob, pred in zip(probs, predictions):
+ pred["probability"] = prob
+
+ # Pick the best prediction and set the probability for the null answer.
+ all_predictions[example["id"]] = predictions[0]["text"]
+ if version_2_with_negative:
+ scores_diff_json[example["id"]] = float(min_null_score)
+
+ # Make `predictions` JSON-serializable by casting np.float back to float.
+ all_nbest_json[example["id"]] = [
+ {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
+ for pred in predictions
+ ]
+
+ # If we have an output_dir, let's save all those dicts.
+ if output_dir is not None:
+ assert os.path.isdir(output_dir), f"{output_dir} is not a directory."
+
+ prediction_file = os.path.join(
+ output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
+ )
+ nbest_file = os.path.join(
+ output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
+ )
+ if version_2_with_negative:
+ null_odds_file = os.path.join(
+ output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
+ )
+
+ logger.info(f"Saving predictions to {prediction_file}.")
+ with open(prediction_file, "w") as writer:
+ writer.write(json.dumps(all_predictions, indent=4) + "\n")
+ logger.info(f"Saving nbest_preds to {nbest_file}.")
+ with open(nbest_file, "w") as writer:
+ writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
+ if version_2_with_negative:
+ logger.info(f"Saving null_odds to {null_odds_file}.")
+ with open(null_odds_file, "w") as writer:
+ writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
+
+ return all_predictions, scores_diff_json
\ No newline at end of file
diff --git a/soft_prompt/tasks/srl/dataset.py b/soft_prompt/tasks/srl/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3a5dc9ab4ac76b0709aeb89652aee350bbaf10d
--- /dev/null
+++ b/soft_prompt/tasks/srl/dataset.py
@@ -0,0 +1,143 @@
+from collections import OrderedDict
+import torch
+from torch.utils import data
+from torch.utils.data import Dataset
+from datasets.arrow_dataset import Dataset as HFDataset
+from datasets.load import load_dataset, load_metric
+from transformers import AutoTokenizer, DataCollatorForTokenClassification, AutoConfig
+import numpy as np
+
+class SRLDataset(Dataset):
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
+ super().__init__()
+ raw_datasets = load_dataset(f'tasks/srl/datasets/{data_args.dataset_name}.py')
+ self.tokenizer = tokenizer
+
+ if training_args.do_train:
+ column_names = raw_datasets["train"].column_names
+ features = raw_datasets["train"].features
+ else:
+ column_names = raw_datasets["validation"].column_names
+ features = raw_datasets["validation"].features
+
+ self.label_column_name = f"tags"
+ self.label_list = features[self.label_column_name].feature.names
+ self.label_to_id = {l: i for i, l in enumerate(self.label_list)}
+ self.num_labels = len(self.label_list)
+
+ if training_args.do_train:
+ train_dataset = raw_datasets['train']
+ if data_args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
+ self.train_dataset = train_dataset.map(
+ self.tokenize_and_align_labels,
+ batched=True,
+ load_from_cache_file=True,
+ desc="Running tokenizer on train dataset",
+ )
+ if training_args.do_eval:
+ eval_dataset = raw_datasets['validation']
+ if data_args.max_eval_samples is not None:
+ eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))
+ self.eval_dataset = eval_dataset.map(
+ self.tokenize_and_align_labels,
+ batched=True,
+ load_from_cache_file=True,
+ desc="Running tokenizer on validation dataset",
+ )
+
+ if training_args.do_predict:
+ if data_args.dataset_name == "conll2005":
+ self.predict_dataset = OrderedDict()
+ self.predict_dataset['wsj'] = raw_datasets['test_wsj'].map(
+ self.tokenize_and_align_labels,
+ batched=True,
+ load_from_cache_file=True,
+ desc="Running tokenizer on WSJ test dataset",
+ )
+
+ self.predict_dataset['brown'] = raw_datasets['test_brown'].map(
+ self.tokenize_and_align_labels,
+ batched=True,
+ load_from_cache_file=True,
+ desc="Running tokenizer on Brown test dataset",
+ )
+ else:
+ self.predict_dataset = raw_datasets['test_wsj'].map(
+ self.tokenize_and_align_labels,
+ batched=True,
+ load_from_cache_file=True,
+ desc="Running tokenizer on WSJ test dataset",
+ )
+
+ self.data_collator = DataCollatorForTokenClassification(self.tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
+ self.metric = load_metric("seqeval")
+
+
+ def compute_metrics(self, p):
+ predictions, labels = p
+ predictions = np.argmax(predictions, axis=2)
+
+ # Remove ignored index (special tokens)
+ true_predictions = [
+ [self.label_list[p] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+ true_labels = [
+ [self.label_list[l] for (p, l) in zip(prediction, label) if l != -100]
+ for prediction, label in zip(predictions, labels)
+ ]
+
+ results = self.metric.compute(predictions=true_predictions, references=true_labels)
+ return {
+ "precision": results["overall_precision"],
+ "recall": results["overall_recall"],
+ "f1": results["overall_f1"],
+ "accuracy": results["overall_accuracy"],
+ }
+
+ def tokenize_and_align_labels(self, examples):
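+ # The predicate token (at position examples['index']) is appended after a [SEP]
+ # marker, so each input looks like "sentence [SEP] predicate"; this appears to be
+ # how the model is told which verb the semantic roles attach to.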
+ for i, tokens in enumerate(examples['tokens']):
+ examples['tokens'][i] = tokens + ["[SEP]"] + [tokens[int(examples['index'][i])]]
+
+ tokenized_inputs = self.tokenizer(
+ examples['tokens'],
+ padding=False,
+ truncation=True,
+ # We use this argument because the texts in our dataset are lists of words (with a label for each word).
+ is_split_into_words=True,
+ )
+
+ # print(tokenized_inputs['input_ids'][0])
+
+ labels = []
+ for i, label in enumerate(examples['tags']):
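+ # word_ids are rebuilt by re-encoding each original word so that the appended
+ # [SEP] and predicate copy map to None (label -100) and the tags stay aligned
+ # with the original sentence only.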
+ word_ids = [None]
+ for j, word in enumerate(examples['tokens'][i][:-2]):
+ token = self.tokenizer.encode(word, add_special_tokens=False)
+ word_ids += [j] * len(token)
+ word_ids += [None]
+ verb = examples['tokens'][i][int(examples['index'][i])]
+ word_ids += [None] * len(self.tokenizer.encode(verb, add_special_tokens=False))
+ word_ids += [None]
+
+ # word_ids = tokenized_inputs.word_ids(batch_index=i)
+ previous_word_idx = None
+ label_ids = []
+ for word_idx in word_ids:
+ # Special tokens have a word id that is None. We set the label to -100 so they are automatically
+ # ignored in the loss function.
+ if word_idx is None:
+ label_ids.append(-100)
+ # We set the label for the first token of each word.
+ elif word_idx != previous_word_idx:
+ label_ids.append(label[word_idx])
+ # For the other tokens in a word, we set the label to either the current label or -100, depending on
+ # the label_all_tokens flag.
+ else:
+ label_ids.append(-100)
+ previous_word_idx = word_idx
+
+ labels.append(label_ids)
+ tokenized_inputs["labels"] = labels
+ return tokenized_inputs
\ No newline at end of file
diff --git a/soft_prompt/tasks/srl/datasets/conll2005.py b/soft_prompt/tasks/srl/datasets/conll2005.py
new file mode 100644
index 0000000000000000000000000000000000000000..3304d670d9fa6f689f17e34e6c15c2595915ba29
--- /dev/null
+++ b/soft_prompt/tasks/srl/datasets/conll2005.py
@@ -0,0 +1,131 @@
+# coding=utf-8
+# Copyright 2020 HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition"""
+
+import datasets
+
+
+logger = datasets.logging.get_logger(__name__)
+
+
+_CITATION = """\
+@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,
+ title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition",
+ author = "Tjong Kim Sang, Erik F. and
+ De Meulder, Fien",
+ booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003",
+ year = "2003",
+ url = "https://www.aclweb.org/anthology/W03-0419",
+ pages = "142--147",
+}
+"""
+
+_DESCRIPTION = """\
+The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on
+four types of named entities: persons, locations, organizations and names of miscellaneous entities that do
+not belong to the previous three groups.
+The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on
+a separate line and there is an empty line after each sentence. The first item on each line is a word, the second
+a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags
+and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only
+if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag
+B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2
+tagging scheme, whereas the original dataset uses IOB1.
+For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419
+"""
+
+_URL = "../../../data/CoNLL05/"
+_TRAINING_FILE = "conll05.train.txt"
+_DEV_FILE = "conll05.devel.txt"
+_TEST_WSJ_FILE = "conll05.test.wsj.txt"
+_TEST_BROWN_FILE = "conll05.test.brown.txt"
+
+
+class Conll2005Config(datasets.BuilderConfig):
+ """BuilderConfig for Conll2003"""
+
+ def __init__(self, **kwargs):
+ """BuilderConfig forConll2005.
+ Args:
+ **kwargs: keyword arguments forwarded to super.
+ """
+ super(Conll2005Config, self).__init__(**kwargs)
+
+
+class Conll2005(datasets.GeneratorBasedBuilder):
+ """Conll2003 dataset."""
+
+ BUILDER_CONFIGS = [
+ Conll2005Config(name="conll2005", version=datasets.Version("1.0.0"), description="Conll2005 dataset"),
+ ]
+
+ def _info(self):
+ return datasets.DatasetInfo(
+ description=_DESCRIPTION,
+ features=datasets.Features(
+ {
+ "id": datasets.Value("string"),
+ "index": datasets.Value("string"),
+ "tokens": datasets.Sequence(datasets.Value("string")),
+ "tags": datasets.Sequence(
+ datasets.features.ClassLabel(
+ names=['B-C-AM-TMP', 'B-C-AM-DIR', 'B-C-A2', 'B-R-AM-EXT', 'B-C-A0', 'I-AM-NEG', 'I-AM-ADV', 'B-C-V', 'B-C-AM-MNR', 'B-R-A3', 'I-AM-TM', 'B-V', 'B-R-A4', 'B-A5', 'I-A4', 'I-R-AM-LOC', 'I-C-A1', 'B-R-AA', 'I-C-A0', 'B-C-AM-EXT', 'I-C-AM-DIS', 'I-C-A5', 'B-A0', 'B-C-A4', 'B-C-AM-CAU', 'B-C-AM-NEG', 'B-AM-NEG', 'I-AM-MNR', 'I-R-A2', 'I-R-AM-TMP', 'B-AM', 'I-R-AM-PNC', 'B-AM-LOC', 'B-AM-REC', 'B-A2', 'I-AM-EXT', 'I-V', 'B-A3', 'B-A4', 'B-R-A0', 'I-AM-MOD', 'I-C-AM-CAU', 'B-R-AM-CAU', 'B-A1', 'B-R-AM-TMP', 'I-R-AM-EXT', 'B-C-AM-ADV', 'B-AM-ADV', 'B-R-A2', 'B-AM-CAU', 'B-R-AM-DIR', 'I-A5', 'B-C-AM-DIS', 'I-C-AM-MNR', 'B-AM-PNC', 'I-C-AM-LOC', 'I-R-A3', 'I-R-AM-ADV', 'I-A0', 'B-AM-EXT', 'B-R-AM-PNC', 'I-AM-DIS', 'I-AM-REC', 'B-C-AM-LOC', 'B-R-AM-ADV', 'I-AM', 'I-AM-CAU', 'I-AM-TMP', 'I-A1', 'I-C-A4', 'B-R-AM-LOC', 'I-C-A2', 'B-C-A5', 'O', 'B-R-AM-MNR', 'I-C-A3', 'I-R-AM-DIR', 'I-AM-PRD', 'B-AM-TM', 'I-A2', 'I-AA', 'I-AM-LOC', 'I-AM-PNC', 'B-AM-MOD', 'B-AM-DIR', 'B-R-A1', 'B-AM-TMP', 'B-AM-MNR', 'I-R-A0', 'B-AM-PRD', 'I-AM-DIR', 'B-AM-DIS', 'I-C-AM-ADV', 'I-R-A1', 'B-C-A3', 'I-R-AM-MNR', 'I-R-A4', 'I-C-AM-PNC', 'I-C-AM-TMP', 'I-C-V', 'I-A3', 'I-C-AM-EXT', 'B-C-A1', 'B-AA', 'I-C-AM-DIR', 'B-C-AM-PNC']
+ )
+ ),
+ }
+ ),
+ supervised_keys=None,
+ homepage="https://www.aclweb.org/anthology/W03-0419/",
+ citation=_CITATION,
+ )
+
+ def _split_generators(self, dl_manager):
+ """Returns SplitGenerators."""
+ urls_to_download = {
+ "train": f"{_URL}{_TRAINING_FILE}",
+ "dev": f"{_URL}{_DEV_FILE}",
+ "test_wsj": f"{_URL}{_TEST_WSJ_FILE}",
+ "test_brown": f"{_URL}{_TEST_BROWN_FILE}"
+ }
+ downloaded_files = dl_manager.download_and_extract(urls_to_download)
+
+ return [
+ datasets.SplitGenerator(name="train", gen_kwargs={"filepath": downloaded_files["train"]}),
+ datasets.SplitGenerator(name="validation", gen_kwargs={"filepath": downloaded_files["dev"]}),
+ datasets.SplitGenerator(name="test_wsj", gen_kwargs={"filepath": downloaded_files["test_wsj"]}),
+ datasets.SplitGenerator(name="test_brown", gen_kwargs={"filepath": downloaded_files["test_brown"]}),
+ ]
+
+ def _generate_examples(self, filepath):
+ logger.info("⏳ Generating examples from = %s", filepath)
+ with open(filepath, encoding="utf-8") as f:
+ guid = 0
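+ # Each line appears to be "<predicate-index> <tokens ...> ||| <tags ...>"; the
+ # index marks the predicate position used by the SRL dataset wrapper.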
+ for line in f:
+ if line != '':
+ index = line.split()[0]
+
+ text = ' '.join(line.split()[1:]).strip()
+ tokens = text.split("|||")[0].split()
+ labels = text.split("|||")[1].split()
+ yield guid, {
+ "id": str(guid),
+ "index": index,
+ "tokens": tokens,
+ "tags": labels
+ }
+
+ guid += 1
\ No newline at end of file
diff --git a/soft_prompt/tasks/srl/datasets/conll2012.py b/soft_prompt/tasks/srl/datasets/conll2012.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ac51f992a203003aa91cf69dc9c3230436fd4a2
--- /dev/null
+++ b/soft_prompt/tasks/srl/datasets/conll2012.py
@@ -0,0 +1,152 @@
+# coding=utf-8
+# Copyright 2020 HuggingFace Datasets Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition"""
+
+import datasets
+
+
+logger = datasets.logging.get_logger(__name__)
+
+
+_CITATION = """\
+@inproceedings{tjong-kim-sang-de-meulder-2003-introduction,
+ title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition",
+ author = "Tjong Kim Sang, Erik F. and
+ De Meulder, Fien",
+ booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003",
+ year = "2003",
+ url = "https://www.aclweb.org/anthology/W03-0419",
+ pages = "142--147",
+}
+"""
+
+_DESCRIPTION = """\
+The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on
+four types of named entities: persons, locations, organizations and names of miscellaneous entities that do
+not belong to the previous three groups.
+The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on
+a separate line and there is an empty line after each sentence. The first item on each line is a word, the second
+a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags
+and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only
+if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag
+B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2
+tagging scheme, whereas the original dataset uses IOB1.
+For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419
+"""
+
+_URL = "../../../data/CoNLL12/"
+_TRAINING_FILE = "conll2012.train.txt"
+_DEV_FILE = "conll2012.devel.txt"
+_TEST_WSJ_FILE = "conll2012.test.txt"
+# _TEST_BROWN_FILE = "conll.test.brown.txt"
+CONLL12_LABELS = ['B-ARG0', 'B-ARGM-MNR', 'B-V', 'B-ARG1', 'B-ARG2', 'I-ARG2', 'O', 'I-ARG1', 'B-ARGM-ADV',
+ 'B-ARGM-LOC', 'I-ARGM-LOC', 'I-ARG0', 'B-ARGM-TMP', 'I-ARGM-TMP', 'B-ARGM-PRP',
+ 'I-ARGM-PRP', 'B-ARGM-PRD', 'I-ARGM-PRD', 'B-R-ARGM-TMP', 'B-ARGM-DIR', 'I-ARGM-DIR',
+ 'B-ARGM-DIS', 'B-ARGM-MOD', 'I-ARGM-ADV', 'I-ARGM-DIS', 'B-R-ARGM-LOC', 'B-ARG4',
+ 'I-ARG4', 'B-R-ARG1', 'B-R-ARG0', 'I-R-ARG0', 'B-ARG3', 'B-ARGM-NEG', 'B-ARGM-CAU',
+ 'I-ARGM-MNR', 'I-R-ARG1', 'B-C-ARG1', 'I-C-ARG1', 'B-ARGM-EXT', 'I-ARGM-EXT', 'I-ARGM-CAU',
+ 'I-ARG3', 'B-C-ARGM-ADV', 'I-C-ARGM-ADV', 'B-ARGM-LVB', 'B-ARGM-REC', 'B-R-ARG3',
+ 'B-R-ARG2', 'B-C-ARG0', 'I-C-ARG0', 'B-ARGM-ADJ', 'B-C-ARG2', 'I-C-ARG2', 'B-R-ARGM-CAU',
+ 'B-R-ARGM-DIR', 'B-ARGM-GOL', 'I-ARGM-GOL', 'B-ARGM-DSP', 'I-ARGM-ADJ', 'I-R-ARG2',
+ 'I-ARGM-NEG', 'B-ARGM-PRR', 'B-R-ARGM-ADV', 'I-R-ARGM-ADV', 'I-R-ARGM-LOC', 'B-ARGA',
+ 'B-R-ARGM-MNR', 'I-R-ARGM-MNR', 'B-ARGM-COM', 'I-ARGM-COM', 'B-ARGM-PRX', 'I-ARGM-REC',
+ 'B-R-ARG4', 'B-C-ARGM-LOC', 'I-C-ARGM-LOC', 'I-R-ARGM-DIR', 'I-ARGA', 'B-C-ARGM-TMP',
+ 'I-C-ARGM-TMP', 'B-C-ARGM-CAU', 'I-C-ARGM-CAU', 'B-R-ARGM-PRD', 'I-R-ARGM-PRD',
+ 'I-R-ARG3', 'B-C-ARG4', 'I-C-ARG4', 'B-ARGM-PNC', 'I-ARGM-PNC', 'B-ARG5', 'I-ARG5',
+ 'B-C-ARGM-PRP', 'I-C-ARGM-PRP', 'B-C-ARGM-MNR', 'I-C-ARGM-MNR', 'I-R-ARGM-TMP',
+ 'B-R-ARG5', 'I-ARGM-DSP', 'B-C-ARGM-DSP', 'I-C-ARGM-DSP', 'B-C-ARG3', 'I-C-ARG3',
+ 'B-R-ARGM-COM', 'I-R-ARGM-COM', 'B-R-ARGM-PRP', 'I-R-ARGM-PRP', 'I-R-ARGM-CAU',
+ 'B-R-ARGM-GOL', 'I-R-ARGM-GOL', 'B-R-ARGM-EXT', 'I-R-ARGM-EXT', 'I-R-ARG4',
+ 'B-C-ARGM-EXT', 'I-C-ARGM-EXT', 'I-ARGM-MOD', 'B-C-ARGM-MOD', 'I-C-ARGM-MOD']
+
+
+class Conll2012Config(datasets.BuilderConfig):
+ """BuilderConfig for Conll2012"""
+
+ def __init__(self, **kwargs):
+ """BuilderConfig for Conll2012.
+ Args:
+ **kwargs: keyword arguments forwarded to super.
+ """
+ super(Conll2012Config, self).__init__(**kwargs)
+
+
+class Conll2012(datasets.GeneratorBasedBuilder):
+ """Conll2012 dataset."""
+
+ BUILDER_CONFIGS = [
+ Conll2012Config(name="conll2012", version=datasets.Version("1.0.0"), description="Conll2012 dataset"),
+ ]
+
+ def _info(self):
+ return datasets.DatasetInfo(
+ description=_DESCRIPTION,
+ features=datasets.Features(
+ {
+ "id": datasets.Value("string"),
+ "index": datasets.Value("string"),
+ "tokens": datasets.Sequence(datasets.Value("string")),
+ "tags": datasets.Sequence(
+ datasets.features.ClassLabel(
+ names=CONLL12_LABELS
+ )
+ ),
+ }
+ ),
+ supervised_keys=None,
+ homepage="https://www.aclweb.org/anthology/W03-0419/",
+ citation=_CITATION,
+ )
+
+ def _split_generators(self, dl_manager):
+ """Returns SplitGenerators."""
+ urls_to_download = {
+ "train": f"{_URL}{_TRAINING_FILE}",
+ "dev": f"{_URL}{_DEV_FILE}",
+ "test_wsj": f"{_URL}{_TEST_WSJ_FILE}",
+ # "test_brown": f"{_URL}{_TEST_BROWN_FILE}"
+ }
+ downloaded_files = dl_manager.download_and_extract(urls_to_download)
+
+ return [
+ datasets.SplitGenerator(name="train", gen_kwargs={"filepath": downloaded_files["train"]}),
+ datasets.SplitGenerator(name="validation", gen_kwargs={"filepath": downloaded_files["dev"]}),
+ datasets.SplitGenerator(name="test_wsj", gen_kwargs={"filepath": downloaded_files["test_wsj"]}),
+ # datasets.SplitGenerator(name="test_brown", gen_kwargs={"filepath": downloaded_files["test_brown"]}),
+ ]
+
+ def _generate_examples(self, filepath):
+ logger.info("⏳ Generating examples from = %s", filepath)
+ with open(filepath, encoding="utf-8") as f:
+ guid = 0
+ for line in f:
+ if line != '':
+ if line.split() == []:
+ continue
+ index = line.split()[0]
+
+ text = ' '.join(line.split()[1:]).strip()
+ tokens = text.split("|||")[0].split()
+ labels = text.split("|||")[1].split()
+ yield guid, {
+ "id": str(guid),
+ "index": index,
+ "tokens": tokens,
+ "tags": labels
+ }
+
+ guid += 1
\ No newline at end of file
diff --git a/soft_prompt/tasks/srl/get_trainer.py b/soft_prompt/tasks/srl/get_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ccd2b95cc66fffe1de735a15b50813846af9252
--- /dev/null
+++ b/soft_prompt/tasks/srl/get_trainer.py
@@ -0,0 +1,61 @@
+import logging
+import os
+import random
+import sys
+
+from transformers import (
+ AutoConfig,
+ AutoTokenizer,
+)
+
+from tasks.srl.dataset import SRLDataset
+from training.trainer_exp import ExponentialTrainer
+from model.utils import get_model, TaskType
+from tasks.utils import ADD_PREFIX_SPACE, USE_FAST
+
+logger = logging.getLogger(__name__)
+
+def get_trainer(args):
+ model_args, data_args, training_args, _ = args
+
+ model_type = AutoConfig.from_pretrained(model_args.model_name_or_path).model_type
+
+ add_prefix_space = ADD_PREFIX_SPACE[model_type]
+
+ use_fast = USE_FAST[model_type]
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=use_fast,
+ revision=model_args.model_revision,
+ add_prefix_space=add_prefix_space,
+ )
+
+ dataset = SRLDataset(tokenizer, data_args, training_args)
+
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ revision=model_args.model_revision,
+ )
+
+ if training_args.do_train:
+ for index in random.sample(range(len(dataset.train_dataset)), 3):
+ logger.info(f"Sample {index} of the training set: {dataset.train_dataset[index]}.")
+
+ model = get_model(model_args, TaskType.TOKEN_CLASSIFICATION, config, fix_bert=False)
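+ # Unlike the NER setup, fix_bert=False here, so the backbone is presumably
+ # fine-tuned together with the prompt parameters for SRL.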
+
+
+ trainer = ExponentialTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
+ predict_dataset=dataset.predict_dataset if training_args.do_predict else None,
+ tokenizer=tokenizer,
+ data_collator=dataset.data_collator,
+ compute_metrics=dataset.compute_metrics,
+ test_key="f1"
+ )
+
+ return trainer, dataset.predict_dataset
\ No newline at end of file
diff --git a/soft_prompt/tasks/superglue/dataset.py b/soft_prompt/tasks/superglue/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e655736f8cd4aaa6fc98aba3f8896476bf7e138
--- /dev/null
+++ b/soft_prompt/tasks/superglue/dataset.py
@@ -0,0 +1,257 @@
+import os.path
+
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ default_data_collator,
+)
+import hashlib, torch
+import numpy as np
+import logging
+from collections import defaultdict
+
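+# Maps each SuperGLUE task to the example fields used as the (first, second) tokenizer inputs;
+# entries that reference derived fields (or (None, None)) are filled in by preprocess_function below.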
+task_to_keys = {
+ "boolq": ("question", "passage"),
+ "cb": ("premise", "hypothesis"),
+ "rte": ("premise", "hypothesis"),
+ "wic": ("processed_sentence1", None),
+ "wsc": ("span2_word_text", "span1_text"),
+ "copa": (None, None),
+ "record": (None, None),
+ "multirc": ("paragraph", "question_answer")
+}
+
+logger = logging.getLogger(__name__)
+
+
+class SuperGlueDataset():
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
+ super().__init__()
+ raw_datasets = load_dataset("super_glue", data_args.dataset_name)
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+
+ self.multiple_choice = data_args.dataset_name in ["copa"]
+
+ if data_args.dataset_name == "record":
+ self.num_labels = 2
+ self.label_list = ["0", "1"]
+ elif not self.multiple_choice:
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+ else:
+ self.num_labels = 1
+
+ # Preprocessing the raw_datasets
+ self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name]
+
+ # Padding strategy
+ if data_args.pad_to_max_length:
+ self.padding = "max_length"
+ else:
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+ self.padding = False
+
+ if not self.multiple_choice:
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
+ self.id2label = {id: label for label, id in self.label2id.items()}
+            logger.info(f"label2id: {self.label2id}")
+            logger.info(f"id2label: {self.id2label}")
+
+ if data_args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+ if data_args.dataset_name == "record":
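+            # ReCoRD expands every passage into one example per candidate entity, which is slow, so
+            # the tokenized dataset is cached on disk under a tokenizer-specific hash and reloaded
+            # on later runs.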
+            digest = hashlib.md5(f"record_{tokenizer.name_or_path}".encode("utf-8")).hexdigest()[:16]  # first 16 hex chars of the digest
+ path = raw_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"record-{digest}.arrow")
+ if not os.path.exists(path):
+                logger.info(f"ReCoRD tokenization cache not found at {path}; building it now.")
+ raw_datasets = raw_datasets.map(
+ self.record_preprocess_function,
+ batched=True,
+ load_from_cache_file=not data_args.overwrite_cache,
+ remove_columns=raw_datasets["train"].column_names,
+ desc="Running tokenizer on dataset",
+ )
+ data = {"raw_datasets": raw_datasets}
+ torch.save(data, path)
+ raw_datasets = torch.load(path)["raw_datasets"]
+ else:
+ raw_datasets = raw_datasets.map(
+ self.preprocess_function,
+ batched=True,
+ load_from_cache_file=not data_args.overwrite_cache,
+ desc="Running tokenizer on dataset",
+ )
+
+ if training_args.do_train:
+ self.train_dataset = raw_datasets["train"]
+ if data_args.max_train_samples is not None:
+ self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))
+
+ if training_args.do_eval:
+ self.eval_dataset = raw_datasets["validation"]
+ if data_args.max_eval_samples is not None:
+ self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))
+
+ if training_args.do_predict or data_args.dataset_name is not None or data_args.test_file is not None:
+ self.predict_dataset = raw_datasets["test"]
+ if data_args.max_predict_samples is not None:
+ self.predict_dataset = self.predict_dataset.select(range(data_args.max_predict_samples))
+
+ self.metric = load_metric("super_glue", data_args.dataset_name)
+
+ if data_args.pad_to_max_length:
+ self.data_collator = default_data_collator
+ elif training_args.fp16:
+            self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+        else:
+            # Let the Trainer fall back to its default dynamic-padding collator.
+            self.data_collator = None
+
+ self.test_key = "accuracy" if data_args.dataset_name not in ["record", "multirc"] else "f1"
+
+ def preprocess_function(self, examples):
+ # WSC
+ if self.data_args.dataset_name == "wsc":
+ examples["span2_word_text"] = []
+ for text, span2_index, span2_word in zip(examples["text"], examples["span2_index"], examples["span2_text"]):
+ if self.data_args.template_id == 0:
+ examples["span2_word_text"].append(span2_word + ": " + text)
+ elif self.data_args.template_id == 1:
+ words_a = text.split()
+ words_a[span2_index] = "*" + words_a[span2_index] + "*"
+ examples["span2_word_text"].append(' '.join(words_a))
+
+ # WiC
+ if self.data_args.dataset_name == "wic":
+ examples["processed_sentence1"] = []
+ if self.data_args.template_id == 1:
+ self.sentence2_key = "processed_sentence2"
+ examples["processed_sentence2"] = []
+ for sentence1, sentence2, word, start1, end1, start2, end2 in zip(examples["sentence1"],
+ examples["sentence2"], examples["word"],
+ examples["start1"], examples["end1"],
+ examples["start2"], examples["end2"]):
+ if self.data_args.template_id == 0: # ROBERTA
+ examples["processed_sentence1"].append(
+ f"{sentence1} {sentence2} Does {word} have the same meaning in both sentences?")
+ elif self.data_args.template_id == 1: # BERT
+ examples["processed_sentence1"].append(word + ": " + sentence1)
+ examples["processed_sentence2"].append(word + ": " + sentence2)
+
+ # MultiRC
+ if self.data_args.dataset_name == "multirc":
+ examples["question_answer"] = []
+            for question, answer in zip(examples["question"], examples["answer"]):
+                examples["question_answer"].append(f"{question} {answer}")
+
+ # COPA
+ if self.data_args.dataset_name == "copa":
+ examples["text_a"] = []
+ for premise, question in zip(examples["premise"], examples["question"]):
+ joiner = "because" if question == "cause" else "so"
+ text_a = f"{premise} {joiner}"
+ examples["text_a"].append(text_a)
+
+ result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding,
+ max_length=self.max_seq_length, truncation=True)
+ result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding,
+ max_length=self.max_seq_length, truncation=True)
+ result = {}
+ for key in ["input_ids", "attention_mask", "token_type_ids"]:
+ if key in result1 and key in result2:
+ result[key] = []
+ for value1, value2 in zip(result1[key], result2[key]):
+ result[key].append([value1, value2])
+ return result
+
+ args = (
+ (examples[self.sentence1_key],) if self.sentence2_key is None else (
+ examples[self.sentence1_key], examples[self.sentence2_key])
+ )
+ result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+
+ return result
+
+ def compute_metrics(self, p: EvalPrediction):
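+        # Predictions arrive as logits: take the argmax over the label dimension, then dispatch to
+        # the task-specific scorer (token-level F1/EM for ReCoRD, binary F1 for MultiRC, otherwise
+        # the official SuperGLUE metric).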
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ preds = np.argmax(preds, axis=1)
+
+        if self.data_args.dataset_name == "record":
+            return self.record_compute_metrics(p)
+
+        if self.data_args.dataset_name == "multirc":
+            from sklearn.metrics import f1_score
+            return {"f1": f1_score(p.label_ids, preds)}
+
+ if self.data_args.dataset_name is not None:
+ result = self.metric.compute(predictions=preds, references=p.label_ids)
+ if len(result) > 1:
+ result["combined_score"] = np.mean(list(result.values())).item()
+ return result
+ elif self.is_regression:
+ return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
+ else:
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
+
+    def record_compute_metrics(self, p: EvalPrediction):
+ from .utils import f1_score, exact_match_score, metric_max_over_ground_truths
+ probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ examples = self.eval_dataset
+ qid2pred = defaultdict(list)
+ qid2ans = {}
+ for prob, example in zip(probs, examples):
+ qid = example['question_id']
+ qid2pred[qid].append((prob[1], example['entity']))
+ if qid not in qid2ans:
+ qid2ans[qid] = example['answers']
+ n_correct, n_total = 0, 0
+ f1, em = 0, 0
+ for qid in qid2pred:
+ preds = sorted(qid2pred[qid], reverse=True)
+ entity = preds[0][1]
+ n_total += 1
+ n_correct += (entity in qid2ans[qid])
+ f1 += metric_max_over_ground_truths(f1_score, entity, qid2ans[qid])
+ em += metric_max_over_ground_truths(exact_match_score, entity, qid2ans[qid])
+ acc = n_correct / n_total
+ f1 = f1 / n_total
+ em = em / n_total
+ return {'f1': f1, 'exact_match': em}
+
+ def record_preprocess_function(self, examples, split="train"):
+ results = {
+ "index": list(),
+ "question_id": list(),
+ "input_ids": list(),
+ "attention_mask": list(),
+ #"token_type_ids": list(),
+ "label": list(),
+ "entity": list(),
+ "answers": list()
+ }
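+        # Every (passage, candidate entity) pair becomes one binary example: label 1 if the entity
+        # is a gold answer, 0 otherwise. The prompt marker tokens ([T], [K], [P]) are stripped from
+        # the raw text before tokenization.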
+ for idx, passage in enumerate(examples["passage"]):
+ query, entities, answers = examples["query"][idx], examples["entities"][idx], examples["answers"][idx]
+ index = examples["idx"][idx]
+ passage = passage.replace("@highlight\n", "- ").replace(self.tokenizer.prompt_token, "").replace(self.tokenizer.skey_token, "").replace(self.tokenizer.predict_token, "")
+
+ for ent_idx, ent in enumerate(entities):
+ question = query.replace("@placeholder", ent).replace(self.tokenizer.prompt_token, "").replace(self.tokenizer.skey_token, "").replace(self.tokenizer.predict_token, "")
+ result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length,
+ truncation=True)
+ label = 1 if ent in answers else 0
+
+ results["input_ids"].append(result["input_ids"])
+ results["attention_mask"].append(result["attention_mask"])
+ #if "token_type_ids" in result.keys(): results["token_type_ids"].append(result["token_type_ids"])
+ results["label"].append(label)
+ results["index"].append(index)
+ results["question_id"].append(index["query"])
+ results["entity"].append(ent)
+ results["answers"].append(answers)
+
+ return results
\ No newline at end of file
diff --git a/soft_prompt/tasks/superglue/dataset_record.py b/soft_prompt/tasks/superglue/dataset_record.py
new file mode 100644
index 0000000000000000000000000000000000000000..a63898418bef33db38a17afb565f42a53675e2d0
--- /dev/null
+++ b/soft_prompt/tasks/superglue/dataset_record.py
@@ -0,0 +1,251 @@
+import torch
+from torch.utils import data
+from torch.utils.data import Dataset
+from datasets.arrow_dataset import Dataset as HFDataset
+from datasets.load import load_dataset, load_metric
+from transformers import (
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ default_data_collator,
+ DataCollatorForLanguageModeling
+)
+import random
+import numpy as np
+import logging
+
+from .dataset import SuperGlueDataset
+
+from dataclasses import dataclass
+from transformers.data.data_collator import DataCollatorMixin
+from transformers.file_utils import PaddingStrategy
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
+
+logger = logging.getLogger(__name__)
+
+@dataclass
+class DataCollatorForMultipleChoice(DataCollatorMixin):
+ tokenizer: PreTrainedTokenizerBase
+ padding: Union[bool, str, PaddingStrategy] = True
+ max_length: Optional[int] = None
+ pad_to_multiple_of: Optional[int] = None
+ label_pad_token_id: int = -100
+ return_tensors: str = "pt"
+
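+    # Pads a list of multiple-choice features into a batch: the tokenizer pads the token ids and,
+    # when per-token label sequences are present, they are padded to the batch sequence length
+    # with `label_pad_token_id` so padded positions can be ignored by the loss.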
+ def torch_call(self, features):
+ label_name = "label" if "label" in features[0].keys() else "labels"
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
+ batch = self.tokenizer.pad(
+ features,
+ padding=self.padding,
+ max_length=self.max_length,
+ pad_to_multiple_of=self.pad_to_multiple_of,
+ # Conversion to tensors will fail if we have labels as they are not of the same length yet.
+ return_tensors="pt" if labels is None else None,
+ )
+
+ if labels is None:
+ return batch
+
+ sequence_length = torch.tensor(batch["input_ids"]).shape[1]
+ padding_side = self.tokenizer.padding_side
+ if padding_side == "right":
+ batch[label_name] = [
+ list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
+ ]
+ else:
+ batch[label_name] = [
+ [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
+ ]
+
+ batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
+ input_list = [sample['input_ids'] for sample in batch]
+
+ choice_nums = list(map(len, input_list))
+ max_choice_num = max(choice_nums)
+
+ def pad_choice_dim(data, choice_num):
+ if len(data) < choice_num:
+ data = np.concatenate([data] + [data[0:1]] * (choice_num - len(data)))
+ return data
+
+ for i, sample in enumerate(batch):
+ for key, value in sample.items():
+ if key != 'label':
+ sample[key] = pad_choice_dim(value, max_choice_num)
+ else:
+ sample[key] = value
+ # sample['loss_mask'] = np.array([1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]),
+ # dtype=np.int64)
+
+ return batch
+
+
+class SuperGlueDatasetForRecord(SuperGlueDataset):
+ def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None:
+ raw_datasets = load_dataset("super_glue", data_args.dataset_name)
+ self.tokenizer = tokenizer
+ self.data_args = data_args
+ #labels
+ self.multiple_choice = data_args.dataset_name in ["copa", "record"]
+
+ if not self.multiple_choice:
+ self.label_list = raw_datasets["train"].features["label"].names
+ self.num_labels = len(self.label_list)
+ else:
+ self.num_labels = 1
+
+ # Padding strategy
+ if data_args.pad_to_max_length:
+ self.padding = "max_length"
+ else:
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+ self.padding = False
+
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
+ self.label_to_id = None
+
+ if self.label_to_id is not None:
+ self.label2id = self.label_to_id
+ self.id2label = {id: label for label, id in self.label2id.items()}
+ elif not self.multiple_choice:
+ self.label2id = {l: i for i, l in enumerate(self.label_list)}
+ self.id2label = {id: label for label, id in self.label2id.items()}
+
+
+ if data_args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+ if training_args.do_train:
+ self.train_dataset = raw_datasets["train"]
+ if data_args.max_train_samples is not None:
+ self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples))
+
+ self.train_dataset = self.train_dataset.map(
+ self.prepare_train_dataset,
+ batched=True,
+ load_from_cache_file=not data_args.overwrite_cache,
+ remove_columns=raw_datasets["train"].column_names,
+ desc="Running tokenizer on train dataset",
+ )
+
+ if training_args.do_eval:
+ self.eval_dataset = raw_datasets["validation"]
+ if data_args.max_eval_samples is not None:
+ self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples))
+
+ self.eval_dataset = self.eval_dataset.map(
+ self.prepare_eval_dataset,
+ batched=True,
+ load_from_cache_file=not data_args.overwrite_cache,
+ remove_columns=raw_datasets["train"].column_names,
+ desc="Running tokenizer on validation dataset",
+ )
+
+ self.metric = load_metric("super_glue", data_args.dataset_name)
+
+ self.data_collator = DataCollatorForMultipleChoice(tokenizer)
+ # if data_args.pad_to_max_length:
+ # self.data_collator = default_data_collator
+ # elif training_args.fp16:
+ # self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+
+    def preprocess_function(self, examples):
+        results = {
+            "input_ids": list(),
+            "attention_mask": list(),
+            "token_type_ids": list(),
+            "label": list()
+        }
+        for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
+            passage = passage.replace("@highlight\n", "- ")
+
+            input_ids = []
+            attention_mask = []
+            token_type_ids = []
+            labels = []
+
+            for ent in entities:
+                question = query.replace("@placeholder", ent)
+                result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+
+                input_ids.append(result["input_ids"])
+                attention_mask.append(result["attention_mask"])
+                if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"])
+                labels.append(1 if ent in answers else 0)
+
+            results["input_ids"].append(input_ids)
+            results["attention_mask"].append(attention_mask)
+            if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids)
+            results["label"].append(labels)
+
+        return results
+
+
+ def prepare_train_dataset(self, examples, max_train_candidates_per_question=10):
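+        # For every gold answer, build a candidate list of at most `max_train_candidates_per_question`
+        # entities: the answer at position 0 plus randomly sampled non-answer entities, so the
+        # training label is always 0 (the index of the correct choice).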
+ entity_shuffler = random.Random(44)
+ results = {
+ "input_ids": list(),
+ "attention_mask": list(),
+ "token_type_ids": list(),
+ "label": list()
+ }
+ for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
+ passage = passage.replace("@highlight\n", "- ")
+
+ for answer in answers:
+ input_ids = []
+ attention_mask = []
+ token_type_ids = []
+ candidates = [ent for ent in entities if ent not in answers]
+ # if len(candidates) < max_train_candidates_per_question - 1:
+ # continue
+ if len(candidates) > max_train_candidates_per_question - 1:
+ entity_shuffler.shuffle(candidates)
+ candidates = candidates[:max_train_candidates_per_question - 1]
+ candidates = [answer] + candidates
+
+ for ent in candidates:
+ question = query.replace("@placeholder", ent)
+ result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+ input_ids.append(result["input_ids"])
+ attention_mask.append(result["attention_mask"])
+ if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"])
+
+ results["input_ids"].append(input_ids)
+ results["attention_mask"].append(attention_mask)
+ if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids)
+ results["label"].append(0)
+
+ return results
+
+
+ def prepare_eval_dataset(self, examples):
+
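+        # At evaluation time every entity stays in the candidate list; the stored label is a
+        # placeholder, since ReCoRD is scored with the token-level F1/EM helpers instead of
+        # choice accuracy.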
+ results = {
+ "input_ids": list(),
+ "attention_mask": list(),
+ "token_type_ids": list(),
+ "label": list()
+ }
+ for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]):
+ passage = passage.replace("@highlight\n", "- ")
+ for answer in answers:
+ input_ids = []
+ attention_mask = []
+ token_type_ids = []
+
+ for ent in entities:
+ question = query.replace("@placeholder", ent)
+ result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True)
+ input_ids.append(result["input_ids"])
+ attention_mask.append(result["attention_mask"])
+ if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"])
+
+ results["input_ids"].append(input_ids)
+ results["attention_mask"].append(attention_mask)
+ if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids)
+ results["label"].append(0)
+
+ return results
diff --git a/soft_prompt/tasks/superglue/get_trainer.py b/soft_prompt/tasks/superglue/get_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7368b13324d6efd40bf4c17e5a3a22f5b1977766
--- /dev/null
+++ b/soft_prompt/tasks/superglue/get_trainer.py
@@ -0,0 +1,106 @@
+import logging
+import os
+import random
+import torch
+
+from transformers import (
+ AutoConfig,
+ AutoTokenizer,
+)
+
+from model.utils import get_model, TaskType
+from tasks.superglue.dataset import SuperGlueDataset
+from training import BaseTrainer
+from training.trainer_exp import ExponentialTrainer
+from tasks import utils
+from .utils import load_from_cache
+
+logger = logging.getLogger(__name__)
+
+def get_trainer(args):
+ model_args, data_args, training_args, _ = args
+ log_level = training_args.get_process_log_level()
+ logger.setLevel(log_level)
+ model_args.model_name_or_path = load_from_cache(model_args.model_name_or_path)
+
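+    # Decoder-only checkpoints (LLaMA, GPT) ship without pad/mask tokens, so those are aliased to
+    # existing special tokens before the dataset is tokenized.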
+ if "llama" in model_args.model_name_or_path:
+ from transformers import LlamaTokenizer
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
+ tokenizer = LlamaTokenizer.from_pretrained(model_path)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.mask_token = tokenizer.unk_token
+ tokenizer.mask_token_id = tokenizer.unk_token_id
+ elif 'gpt' in model_args.model_name_or_path:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+        tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+ else:
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.model_name_or_path,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ )
+ tokenizer = utils.add_task_specific_tokens(tokenizer)
+ dataset = SuperGlueDataset(tokenizer, data_args, training_args)
+
+ if training_args.do_train:
+ for index in random.sample(range(len(dataset.train_dataset)), 3):
+ logger.info(f"Sample {index} of the training set: {dataset.train_dataset[index]}.")
+
+ if not dataset.multiple_choice:
+ if "llama" in model_args.model_name_or_path:
+ model_path = f'openlm-research/{model_args.model_name_or_path}'
+ config = AutoConfig.from_pretrained(
+ model_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ label2id=dataset.label2id,
+ id2label=dataset.id2label,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ trust_remote_code=True
+ )
+ else:
+ config = AutoConfig.from_pretrained(
+ model_args.model_name_or_path,
+ num_labels=dataset.num_labels,
+ finetuning_task=data_args.dataset_name,
+ revision=model_args.model_revision,
+ )
+
+ config.trigger = training_args.trigger
+ config.clean_labels = training_args.clean_labels
+ config.target_labels = training_args.target_labels
+
+ if not dataset.multiple_choice:
+ model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config)
+ else:
+ model = get_model(model_args, TaskType.MULTIPLE_CHOICE, config, fix_bert=True)
+
+ # Initialize our Trainer
+ trainer = BaseTrainer(
+ model=model,
+ args=training_args,
+ train_dataset=dataset.train_dataset if training_args.do_train else None,
+ eval_dataset=dataset.eval_dataset if training_args.do_eval else None,
+ compute_metrics=dataset.compute_metrics,
+ tokenizer=tokenizer,
+ data_collator=dataset.data_collator,
+ test_key=dataset.test_key
+ )
+
+
+ return trainer, None
diff --git a/soft_prompt/tasks/superglue/utils.py b/soft_prompt/tasks/superglue/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..544e3f92b61f8fd6cae994d9f98677ec8e2d7fd9
--- /dev/null
+++ b/soft_prompt/tasks/superglue/utils.py
@@ -0,0 +1,51 @@
+import re, os
+import string
+from collections import defaultdict, Counter
+
+def load_from_cache(model_name):
+ path = os.path.join("hub/models", model_name)
+ if os.path.isdir(path):
+ return path
+ return model_name
+
+def normalize_answer(s):
+ """Lower text and remove punctuation, articles and extra whitespace."""
+
+ def remove_articles(text):
+ return re.sub(r'\b(a|an|the)\b', ' ', text)
+
+ def white_space_fix(text):
+ return ' '.join(text.split())
+
+ def remove_punc(text):
+ exclude = set(string.punctuation)
+ return ''.join(ch for ch in text if ch not in exclude)
+
+ def lower(text):
+ return text.lower()
+
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
+
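+# SQuAD-style token-overlap F1 between a predicted string and a single ground-truth answer,
+# computed over the normalized forms produced by normalize_answer().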
+def f1_score(prediction, ground_truth):
+ prediction_tokens = normalize_answer(prediction).split()
+ ground_truth_tokens = normalize_answer(ground_truth).split()
+ common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+ num_same = sum(common.values())
+ if num_same == 0:
+ return 0
+ precision = 1.0 * num_same / len(prediction_tokens)
+ recall = 1.0 * num_same / len(ground_truth_tokens)
+ f1 = (2 * precision * recall) / (precision + recall)
+ return f1
+
+
+def exact_match_score(prediction, ground_truth):
+ return normalize_answer(prediction) == normalize_answer(ground_truth)
+
+
+def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
+ scores_for_ground_truths = []
+ for ground_truth in ground_truths:
+ score = metric_fn(prediction, ground_truth)
+ scores_for_ground_truths.append(score)
+ return max(scores_for_ground_truths)
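+
+# Example (hypothetical strings) of how these helpers combine when scoring ReCoRD predictions:
+#   metric_max_over_ground_truths(f1_score, "the Eiffel Tower", ["Eiffel Tower", "Paris"])
+# returns 1.0, because normalize_answer() strips the leading article before token overlap is measured.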
\ No newline at end of file
diff --git a/soft_prompt/tasks/utils.py b/soft_prompt/tasks/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..53d0bb97b383043df45e92a1bb5d577f7117abcc
--- /dev/null
+++ b/soft_prompt/tasks/utils.py
@@ -0,0 +1,79 @@
+import os
+import torch
+from tqdm import tqdm
+from tasks.glue.dataset import task_to_keys as glue_tasks
+from tasks.superglue.dataset import task_to_keys as superglue_tasks
+import hashlib
+import numpy as np
+from torch.nn.utils.rnn import pad_sequence
+
+GLUE_DATASETS = list(glue_tasks.keys())
+SUPERGLUE_DATASETS = list(superglue_tasks.keys())
+NER_DATASETS = ["conll2003", "conll2004", "ontonotes"]
+SRL_DATASETS = ["conll2005", "conll2012"]
+QA_DATASETS = ["squad", "squad_v2"]
+
+
+TASKS = ["glue", "superglue", "ner", "srl", "qa", "ag_news", "imdb"]
+
+DATASETS = GLUE_DATASETS + SUPERGLUE_DATASETS + NER_DATASETS + SRL_DATASETS + QA_DATASETS + ["ag_news", "imdb"]
+
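+# Per-model-family tokenizer quirks: byte-level BPE tokenizers (RoBERTa, GPT-2, OPT, DeBERTa)
+# need add_prefix_space=True when given pre-tokenized input, while DeBERTa-v2 is pinned to its
+# slow tokenizer here.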
+ADD_PREFIX_SPACE = {
+ 'bert': False,
+ 'roberta': True,
+ 'deberta': True,
+ 'gpt2': True,
+ 'opt': True,
+ 'deberta-v2': True,
+}
+
+USE_FAST = {
+ 'bert': True,
+ 'roberta': True,
+ 'deberta': True,
+ 'gpt2': True,
+ 'opt': True,
+ 'deberta-v2': False,
+}
+
+def add_task_specific_tokens(tokenizer):
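+    # Register the marker tokens used by the prompt templates: [T] marks prompt/trigger slots,
+    # [K] marks secret-key slots, [P] marks the prediction slot, and [Y] is added as a special
+    # token but not bound to a tokenizer attribute here.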
+ tokenizer.add_special_tokens({
+ 'additional_special_tokens': ['[P]', '[T]', '[K]', '[Y]']
+ })
+ tokenizer.skey_token = '[K]'
+ tokenizer.skey_token_id = tokenizer.convert_tokens_to_ids('[K]')
+ tokenizer.prompt_token = '[T]'
+ tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]')
+ tokenizer.predict_token = '[P]'
+ tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]')
+ # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token...
+ # tokenizer.lama_x = '[X]'
+ # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]')
+ # tokenizer.lama_y = '[Y]'
+ # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]')
+
+ # only for GPT2
+ if 'gpt' in tokenizer.name_or_path or 'opt' in tokenizer.name_or_path:
+ tokenizer.mask_token = tokenizer.unk_token
+ tokenizer.pad_token = tokenizer.unk_token
+ return tokenizer
+
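+# Sketch of typical usage (call sites assumed from the trainers above): after
+# `tokenizer = add_task_specific_tokens(AutoTokenizer.from_pretrained(...))`, the model's embedding
+# matrix generally needs `model.resize_token_embeddings(len(tokenizer))`; that step is presumably
+# handled inside `get_model`.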
+
+def load_cache_record(datasets):
+    digest = hashlib.md5("record".encode("utf-8")).hexdigest()  # used as a cache-key suffix
+    path = datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow")
+    if os.path.exists(path):
+        return torch.load(path)
+    return None
\ No newline at end of file
diff --git a/soft_prompt/training/__init__.py b/soft_prompt/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..abba471bd45ad30ebea2bc23363270d936f286d9
--- /dev/null
+++ b/soft_prompt/training/__init__.py
@@ -0,0 +1,2 @@
+from .trainer_base import BaseTrainer
+from .trainer import Trainer
\ No newline at end of file
diff --git a/soft_prompt/training/trainer.py b/soft_prompt/training/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3d87d61f98a37df45bf757f3eb9c46ca9cf43d
--- /dev/null
+++ b/soft_prompt/training/trainer.py
@@ -0,0 +1,3990 @@
+# coding=utf-8
+# Copyright 2020-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+The Trainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task.
+"""
+
+import contextlib
+import functools
+import glob
+import inspect
+import math
+import os
+import random
+import re
+import shutil
+import sys
+import time
+import warnings
+from collections.abc import Mapping
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+from tqdm.auto import tqdm
+
+
+# Integrations must be imported before ML frameworks:
+# isort: off
+from transformers.integrations import (
+ default_hp_search_backend,
+ get_reporting_integration_callbacks,
+ hp_params,
+ is_fairscale_available,
+ is_optuna_available,
+ is_ray_tune_available,
+ is_sigopt_available,
+ is_wandb_available,
+ run_hp_search_optuna,
+ run_hp_search_ray,
+ run_hp_search_sigopt,
+ run_hp_search_wandb,
+)
+
+# isort: on
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from huggingface_hub import Repository, create_repo
+from packaging import version
+from torch import nn
+from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
+from torch.utils.data.distributed import DistributedSampler
+
+from transformers import __version__
+from transformers.configuration_utils import PretrainedConfig
+from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
+from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
+from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
+from transformers.dependency_versions_check import dep_version_check
+from transformers.modelcard import TrainingSummary
+from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
+from transformers.optimization import Adafactor, get_scheduler
+from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.trainer_callback import (
+ CallbackHandler,
+ DefaultFlowCallback,
+ PrinterCallback,
+ ProgressCallback,
+ TrainerCallback,
+ TrainerControl,
+ TrainerState,
+)
+from transformers.trainer_pt_utils import (
+ DistributedLengthGroupedSampler,
+ DistributedSamplerWithLoop,
+ DistributedTensorGatherer,
+ IterableDatasetShard,
+ LabelSmoother,
+ LengthGroupedSampler,
+ SequentialDistributedSampler,
+ ShardSampler,
+ distributed_broadcast_scalars,
+ distributed_concat,
+ find_batch_size,
+ get_model_param_count,
+ get_module_class_from_name,
+ get_parameter_names,
+ nested_concat,
+ nested_detach,
+ nested_numpify,
+ nested_truncate,
+ nested_xla_mesh_reduce,
+ reissue_pt_warnings,
+)
+from transformers.trainer_utils import (
+ PREFIX_CHECKPOINT_DIR,
+ BestRun,
+ EvalLoopOutput,
+ EvalPrediction,
+ FSDPOption,
+ HPSearchBackend,
+ HubStrategy,
+ IntervalStrategy,
+ PredictionOutput,
+ RemoveColumnsCollator,
+ ShardedDDPOption,
+ TrainerMemoryTracker,
+ TrainOutput,
+ default_compute_objective,
+ default_hp_space,
+ denumpify_detensorize,
+ enable_full_determinism,
+ find_executable_batch_size,
+ get_last_checkpoint,
+ has_length,
+ number_of_arguments,
+ seed_worker,
+ set_seed,
+ speed_metrics,
+)
+from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
+from transformers.utils import (
+ ADAPTER_SAFE_WEIGHTS_NAME,
+ ADAPTER_WEIGHTS_NAME,
+ CONFIG_NAME,
+ SAFE_WEIGHTS_INDEX_NAME,
+ SAFE_WEIGHTS_NAME,
+ WEIGHTS_INDEX_NAME,
+ WEIGHTS_NAME,
+ can_return_loss,
+ find_labels,
+ get_full_repo_name,
+ is_accelerate_available,
+ is_apex_available,
+ is_datasets_available,
+ is_in_notebook,
+ is_ipex_available,
+ is_peft_available,
+ is_safetensors_available,
+ is_sagemaker_dp_enabled,
+ is_sagemaker_mp_enabled,
+ is_torch_compile_available,
+ is_torch_neuroncore_available,
+ is_torch_tpu_available,
+ logging,
+ strtobool,
+)
+from transformers.utils.generic import ContextManagers
+
+
+_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10
+
+DEFAULT_CALLBACKS = [DefaultFlowCallback]
+DEFAULT_PROGRESS_CALLBACK = ProgressCallback
+
+if is_in_notebook():
+ from transformers.utils.notebook import NotebookProgressCallback
+
+ DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
+
+if is_apex_available():
+ from apex import amp
+
+if is_datasets_available():
+ import datasets
+
+if is_torch_tpu_available(check_device=False):
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+ import torch_xla.distributed.parallel_loader as pl
+
+if is_fairscale_available():
+ dep_version_check("fairscale")
+ import fairscale
+ from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
+ from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
+ from fairscale.nn.wrap import auto_wrap
+ from fairscale.optim import OSS
+ from fairscale.optim.grad_scaler import ShardedGradScaler
+
+
+if is_sagemaker_mp_enabled():
+ import smdistributed.modelparallel.torch as smp
+ from smdistributed.modelparallel import __version__ as SMP_VERSION
+
+ IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
+
+ from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
+else:
+ IS_SAGEMAKER_MP_POST_1_10 = False
+
+
+if is_safetensors_available():
+ import safetensors.torch
+
+
+if is_peft_available():
+ from peft import PeftModel
+
+
+skip_first_batches = None
+if is_accelerate_available():
+ from accelerate import __version__ as accelerate_version
+
+ if version.parse(accelerate_version) >= version.parse("0.16"):
+ from accelerate import skip_first_batches
+
+ from accelerate import Accelerator
+ from accelerate.utils import DistributedDataParallelKwargs
+
+
+if TYPE_CHECKING:
+ import optuna
+
+logger = logging.get_logger(__name__)
+
+
+# Name of the files used for checkpointing
+TRAINING_ARGS_NAME = "training_args.bin"
+TRAINER_STATE_NAME = "trainer_state.json"
+OPTIMIZER_NAME = "optimizer.pt"
+SCHEDULER_NAME = "scheduler.pt"
+SCALER_NAME = "scaler.pt"
+
+
+class Trainer:
+ """
+ Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers.
+
+ Args:
+ model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*):
+ The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed.
+
+
+
+ [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use
+ your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers
+ models.
+
+
+
+ args ([`TrainingArguments`], *optional*):
+ The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the
+ `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.
+ data_collator (`DataCollator`, *optional*):
+ The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will
+ default to [`default_data_collator`] if no `tokenizer` is provided, an instance of
+ [`DataCollatorWithPadding`] otherwise.
+ train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
+ The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed.
+
+ Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a
+            distributed fashion, your iterable dataset should either use an internal attribute `generator` that is a
+ `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will
+ manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally
+ sets the seed of the RNGs used.
+ eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*):
+ The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each
+ dataset prepending the dictionary key to the metric name.
+ tokenizer ([`PreTrainedTokenizerBase`], *optional*):
+ The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the
+ maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an
+ interrupted training or reuse the fine-tuned model.
+ model_init (`Callable[[], PreTrainedModel]`, *optional*):
+ A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start
+ from a new instance of the model as given by this function.
+
+            The function may have zero arguments, or a single one containing the optuna/Ray Tune/SigOpt trial object, to
+            be able to choose different architectures according to hyperparameters (such as layer count, sizes of
+            inner layers, dropout probabilities, etc.).
+ compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
+ The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return
+ a dictionary string to metric values.
+ callbacks (List of [`TrainerCallback`], *optional*):
+ A list of callbacks to customize the training loop. Will add those to the list of default callbacks
+ detailed in [here](callback).
+
+ If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method.
+ optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple
+ containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model
+ and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
+ preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
+            A function that preprocesses the logits right before caching them at each evaluation step. Must take two
+ tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
+ by this function will be reflected in the predictions received by `compute_metrics`.
+
+ Note that the labels (second parameter) will be `None` if the dataset does not have them.
+
+ Important attributes:
+
+ - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`]
+ subclass.
+ - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the
+ original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`,
+ the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner
+ model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`.
+ - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from
+ data parallelism, this means some of the model layers are split on different GPUs).
+ - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set
+ to `False` if model parallel or deepspeed is used, or if the default
+ `TrainingArguments.place_model_on_device` is overridden to return `False` .
+ - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while
+ in `train`)
+
+ """
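+    # Minimal usage sketch (assuming `model`, `train_ds` and `eval_ds` are already built):
+    #   trainer = Trainer(model=model, args=TrainingArguments(output_dir="out"),
+    #                     train_dataset=train_ds, eval_dataset=eval_ds)
+    #   trainer.train(); metrics = trainer.evaluate()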
+
+ from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state
+
+ def __init__(
+ self,
+ model: Union[PreTrainedModel, nn.Module] = None,
+ args: TrainingArguments = None,
+ data_collator: Optional[DataCollator] = None,
+ train_dataset: Optional[Dataset] = None,
+ eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
+ tokenizer: Optional[PreTrainedTokenizerBase] = None,
+ model_init: Optional[Callable[[], PreTrainedModel]] = None,
+ compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
+ callbacks: Optional[List[TrainerCallback]] = None,
+ optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
+ preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
+ ):
+ if args is None:
+ output_dir = "tmp_trainer"
+ logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.")
+ args = TrainingArguments(output_dir=output_dir)
+ self.args = args
+        # Seed must be set before instantiating the model when using `model_init`.
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+ self.hp_name = None
+ self.is_in_train = False
+
+ self.create_accelerator_and_postprocess()
+
+ # memory metrics - must set up as early as possible
+ self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
+ self._memory_tracker.start()
+
+ # set the correct log level depending on the node
+ log_level = args.get_process_log_level()
+ logging.set_verbosity(log_level)
+
+ # force device and distributed setup init explicitly
+ args._setup_devices
+
+ if model is None:
+ if model_init is not None:
+ self.model_init = model_init
+ model = self.call_model_init()
+ else:
+ raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument")
+ else:
+ if model_init is not None:
+ warnings.warn(
+ "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will"
+ " overwrite your model when calling the `train` method. This will become a fatal error in the next"
+ " release.",
+ FutureWarning,
+ )
+ self.model_init = model_init
+
+ if model.__class__.__name__ in MODEL_MAPPING_NAMES:
+ raise ValueError(
+ f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only "
+ "computes hidden states and does not accept any labels. You should choose a model with a head "
+ "suitable for your task like any of the `AutoModelForXxx` listed at "
+ "https://huggingface.co/docs/transformers/model_doc/auto."
+ )
+
+ if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel:
+ self.is_model_parallel = True
+ else:
+ self.is_model_parallel = False
+
+ if getattr(model, "hf_device_map", None) is not None:
+ devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]]
+ if len(devices) > 1:
+ self.is_model_parallel = True
+ else:
+ self.is_model_parallel = self.args.device != torch.device(devices[0])
+
+ # warn users
+ logger.info(
+ "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set"
+ " to `True` to avoid any unexpected behavior such as device placement mismatching."
+ )
+
+ # At this stage the model is already loaded
+ if getattr(model, "is_quantized", False):
+ if getattr(model, "_is_quantized_training_enabled", False):
+ logger.info(
+ "The model is loaded in 8-bit precision. To train this model you need to add additional modules"
+ " inside the model such as adapters using `peft` library and freeze the model weights. Please"
+ " check "
+ " the examples in https://github.com/huggingface/peft for more details."
+ )
+ else:
+ raise ValueError(
+                    "The model you want to train is loaded in 8-bit precision. If you want to fine-tune an 8-bit"
+ " model, please make sure that you have installed `bitsandbytes>=0.37.0`. "
+ )
+
+ # Setup Sharded DDP training
+ self.sharded_ddp = None
+ if len(args.sharded_ddp) > 0:
+ if self.is_deepspeed_enabled:
+ raise ValueError(
+ "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
+ )
+ if len(args.fsdp) > 0:
+ raise ValueError(
+ "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags."
+ )
+ if args.parallel_mode != ParallelMode.DISTRIBUTED:
+ raise ValueError("Using sharded DDP only works in distributed training.")
+ elif not is_fairscale_available():
+ raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
+ elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
+ raise ImportError(
+ "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
+ f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
+ )
+ elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
+ self.sharded_ddp = ShardedDDPOption.SIMPLE
+ elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
+ self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
+ elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
+ self.sharded_ddp = ShardedDDPOption.ZERO_DP_3
+
+ self.fsdp = None
+ if len(args.fsdp) > 0:
+ if self.is_deepspeed_enabled:
+ raise ValueError(
+ "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags."
+ )
+ if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED:
+ raise ValueError("Using fsdp only works in distributed training.")
+
+ # dep_version_check("torch>=1.12.0")
+ # Would have to update setup.py with torch>=1.12.0
+ # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0
+ # below is the current alternative.
+ if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"):
+ raise ValueError("FSDP requires PyTorch >= 1.12.0")
+
+ from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy
+
+ if FSDPOption.FULL_SHARD in args.fsdp:
+ self.fsdp = ShardingStrategy.FULL_SHARD
+ elif FSDPOption.SHARD_GRAD_OP in args.fsdp:
+ self.fsdp = ShardingStrategy.SHARD_GRAD_OP
+ elif FSDPOption.NO_SHARD in args.fsdp:
+ self.fsdp = ShardingStrategy.NO_SHARD
+
+ self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
+ if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
+ "backward_prefetch", []
+ ):
+ self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
+
+ self.forward_prefetch = False
+            if self.args.fsdp_config.get("forward_prefetch", False):
+ self.forward_prefetch = True
+
+ self.limit_all_gathers = False
+ if self.args.fsdp_config.get("limit_all_gathers", False):
+ self.limit_all_gathers = True
+
+ # one place to sort out whether to place the model on device or not
+ # postpone switching model to cuda when:
+ # 1. MP - since we are trying to fit a much bigger than 1 gpu model
+ # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
+ # and we only use deepspeed for training at the moment
+ # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
+ # 4. Sharded DDP - same as MP
+ # 5. FSDP - same as MP
+ self.place_model_on_device = args.place_model_on_device
+ if (
+ self.is_model_parallel
+ or self.is_deepspeed_enabled
+ or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
+ or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
+ or (self.fsdp is not None)
+ or self.is_fsdp_enabled
+ ):
+ self.place_model_on_device = False
+
+ default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
+ self.data_collator = data_collator if data_collator is not None else default_collator
+ self.train_dataset = train_dataset
+ self.eval_dataset = eval_dataset
+ self.tokenizer = tokenizer
+
+ if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False):
+ self._move_model_to_device(model, args.device)
+
+ # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs
+ if self.is_model_parallel:
+ self.args._n_gpu = 1
+
+ # later use `self.model is self.model_wrapped` to check if it's wrapped or not
+ self.model_wrapped = model
+ self.model = model
+
+ self.compute_metrics = compute_metrics
+ self.preprocess_logits_for_metrics = preprocess_logits_for_metrics
+ self.optimizer, self.lr_scheduler = optimizers
+ if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None):
+ raise RuntimeError(
+ "Passing a `model_init` is incompatible with providing the `optimizers` argument. "
+ "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+ )
+ if is_torch_tpu_available() and self.optimizer is not None:
+ for param in self.model.parameters():
+ model_device = param.device
+ break
+ for param_group in self.optimizer.param_groups:
+ if len(param_group["params"]) > 0:
+ optimizer_device = param_group["params"][0].device
+ break
+ if model_device != optimizer_device:
+ raise ValueError(
+ "The model and the optimizer parameters are not on the same device, which probably means you"
+ " created an optimizer around your model **before** putting on the device and passing it to the"
+ " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and"
+ " `model.to(xm.xla_device())` is performed before the optimizer creation in your script."
+ )
+ if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and (
+ self.optimizer is not None or self.lr_scheduler is not None
+ ):
+ raise RuntimeError(
+                "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled. "
+ "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
+ )
+ default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
+ callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
+ self.callback_handler = CallbackHandler(
+ callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler
+ )
+ self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
+
+ # Will be set to True by `self._setup_loggers()` on first call to `self.log()`.
+ self._loggers_initialized = False
+
+ # Create clone of distant repo and output directory if needed
+ if self.args.push_to_hub:
+ self.init_git_repo(at_init=True)
+ # In case of pull, we need to make sure every process has the latest.
+ if is_torch_tpu_available():
+ xm.rendezvous("init git repo")
+ elif args.parallel_mode == ParallelMode.DISTRIBUTED:
+ dist.barrier()
+
+ if self.args.should_save:
+ os.makedirs(self.args.output_dir, exist_ok=True)
+
+ if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)):
+ raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).")
+
+ if args.max_steps > 0:
+ logger.info("max_steps is given, it will override any value given in num_train_epochs")
+
+ if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0:
+ raise ValueError(
+ "The train_dataset does not implement __len__, max_steps has to be specified. "
+ "The number of steps needs to be known in advance for the learning rate scheduler."
+ )
+
+ if (
+ train_dataset is not None
+ and isinstance(train_dataset, torch.utils.data.IterableDataset)
+ and args.group_by_length
+ ):
+            raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset`")
+
+ self._signature_columns = None
+
+ # Mixed precision setup
+ self.use_apex = False
+ self.use_cuda_amp = False
+ self.use_cpu_amp = False
+
+ # Mixed precision setup for SageMaker Model Parallel
+ if is_sagemaker_mp_enabled():
+ # BF16 + model parallelism in SageMaker: currently not supported, raise an error
+ if args.bf16:
+ raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ")
+
+ if IS_SAGEMAKER_MP_POST_1_10:
+ # When there's mismatch between SMP config and trainer argument, use SMP config as truth
+ if args.fp16 != smp.state.cfg.fp16:
+ logger.warning(
+                        f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+                        f"but FP16 provided in trainer argument is {args.fp16}, "
+ f"setting to {smp.state.cfg.fp16}"
+ )
+ args.fp16 = smp.state.cfg.fp16
+ else:
+ # smp < 1.10 does not support fp16 in trainer.
+ if hasattr(smp.state.cfg, "fp16"):
+ logger.warning(
+ f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, "
+ "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer."
+ )
+
+ if (args.fp16 or args.bf16) and self.sharded_ddp is not None:
+ if args.half_precision_backend == "auto":
+ if args.device == torch.device("cpu"):
+ if args.fp16:
+ raise ValueError("Tried to use `fp16` but it is not supported on cpu")
+ elif _is_native_cpu_amp_available:
+ args.half_precision_backend = "cpu_amp"
+ else:
+ raise ValueError("Tried to use cpu amp but native cpu amp is not available")
+ else:
+ args.half_precision_backend = "cuda_amp"
+
+ logger.info(f"Using {args.half_precision_backend} half precision backend")
+
+ self.do_grad_scaling = False
+ if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()):
+ # deepspeed and SageMaker Model Parallel manage their own half precision
+ if self.sharded_ddp is not None:
+ if args.half_precision_backend == "cuda_amp":
+ self.use_cuda_amp = True
+ self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
+ # bf16 does not need grad scaling
+ self.do_grad_scaling = self.amp_dtype == torch.float16
+ if self.do_grad_scaling:
+ if self.sharded_ddp is not None:
+ self.scaler = ShardedGradScaler()
+ elif self.fsdp is not None:
+ from torch.distributed.fsdp.sharded_grad_scaler import (
+ ShardedGradScaler as FSDPShardedGradScaler,
+ )
+
+ self.scaler = FSDPShardedGradScaler()
+ elif is_torch_tpu_available():
+ from torch_xla.amp import GradScaler
+
+ self.scaler = GradScaler()
+ else:
+ self.scaler = torch.cuda.amp.GradScaler()
+ elif args.half_precision_backend == "cpu_amp":
+ self.use_cpu_amp = True
+ self.amp_dtype = torch.bfloat16
+ elif args.half_precision_backend == "apex":
+ if not is_apex_available():
+ raise ImportError(
+ "Using FP16 with APEX but APEX is not installed, please refer to"
+ " https://www.github.com/nvidia/apex."
+ )
+ self.use_apex = True
+
+ # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error.
+ if (
+ is_sagemaker_mp_enabled()
+ and self.use_cuda_amp
+ and args.max_grad_norm is not None
+ and args.max_grad_norm > 0
+ ):
+ raise ValueError(
+ "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass "
+ "along 'max_grad_norm': 0 in your hyperparameters."
+ )
+
+ # Label smoothing
+ if self.args.label_smoothing_factor != 0:
+ self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
+ else:
+ self.label_smoother = None
+
+ self.state = TrainerState(
+ is_local_process_zero=self.is_local_process_zero(),
+ is_world_process_zero=self.is_world_process_zero(),
+ )
+
+ self.control = TrainerControl()
+ # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then
+ # returned to 0 every time flos need to be logged
+ self.current_flos = 0
+ self.hp_search_backend = None
+ self.use_tune_checkpoints = False
+ default_label_names = find_labels(self.model.__class__)
+ self.label_names = default_label_names if self.args.label_names is None else self.args.label_names
+ self.can_return_loss = can_return_loss(self.model.__class__)
+ self.control = self.callback_handler.on_init_end(self.args, self.state, self.control)
+
+ # Internal variables to keep track of the original batch size
+ self._train_batch_size = args.train_batch_size
+
+ # very last
+ self._memory_tracker.stop_and_update_metrics()
+
+ # torch.compile
+ if args.torch_compile and not is_torch_compile_available():
+ raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.")
+
+ def add_callback(self, callback):
+ """
+        Add a callback to the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+            callback (`type` or [`~transformers.TrainerCallback`]):
+                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+                first case, will instantiate a member of that class.
+ """
+ self.callback_handler.add_callback(callback)
+
+ def pop_callback(self, callback):
+ """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`] and return it.
+
+        If the callback is not found, returns `None` (and no error is raised).
+
+        Args:
+            callback (`type` or [`~transformers.TrainerCallback`]):
+                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+                first case, will pop the first member of that class found in the list of callbacks.
+
+        Returns:
+            [`~transformers.TrainerCallback`]: The callback removed, if found.
+ """
+ return self.callback_handler.pop_callback(callback)
+
+ def remove_callback(self, callback):
+ """
+        Remove a callback from the current list of [`~transformers.TrainerCallback`].
+
+        Args:
+            callback (`type` or [`~transformers.TrainerCallback`]):
+                A [`~transformers.TrainerCallback`] class or an instance of a [`~transformers.TrainerCallback`]. In the
+                first case, will remove the first member of that class found in the list of callbacks.
+ """
+ self.callback_handler.remove_callback(callback)
+
+ def _move_model_to_device(self, model, device):
+ model = model.to(device)
+ # Moving a model to an XLA device disconnects the tied weights, so we have to retie them.
+ if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"):
+ model.tie_weights()
+
+ def _set_signature_columns_if_needed(self):
+ if self._signature_columns is None:
+ # Inspect model forward signature to keep only the arguments it accepts.
+ signature = inspect.signature(self.model.forward)
+ self._signature_columns = list(signature.parameters.keys())
+            # Labels may be named label or label_ids; the default data collator handles that.
+ self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
+
+ def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None):
+ if not self.args.remove_unused_columns:
+ return dataset
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ ignored_columns = list(set(dataset.column_names) - set(signature_columns))
+ if len(ignored_columns) > 0:
+ dset_description = "" if description is None else f"in the {description} set"
+ logger.info(
+ f"The following columns {dset_description} don't have a corresponding argument in "
+ f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
+ f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
+ " you can safely ignore this message."
+ )
+
+ columns = [k for k in signature_columns if k in dataset.column_names]
+
+ if version.parse(datasets.__version__) < version.parse("1.4.0"):
+ dataset.set_format(
+ type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
+ )
+ return dataset
+ else:
+ return dataset.remove_columns(ignored_columns)
+
+ def _get_collator_with_removed_columns(
+ self, data_collator: Callable, description: Optional[str] = None
+ ) -> Callable:
+ """Wrap the data collator in a callable removing unused columns."""
+ if not self.args.remove_unused_columns:
+ return data_collator
+ self._set_signature_columns_if_needed()
+ signature_columns = self._signature_columns
+
+ remove_columns_collator = RemoveColumnsCollator(
+ data_collator=data_collator,
+ signature_columns=signature_columns,
+ logger=logger,
+ description=description,
+ model_name=self.model.__class__.__name__,
+ )
+ return remove_columns_collator
+
+ def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+ if self.train_dataset is None or not has_length(self.train_dataset):
+ return None
+
+ generator = None
+ if self.args.world_size <= 1:
+ generator = torch.Generator()
+ # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with
+ # `args.seed`) if data_seed isn't provided.
+ # Further on in this method, we default to `args.seed` instead.
+ if self.args.data_seed is None:
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ else:
+ seed = self.args.data_seed
+ generator.manual_seed(seed)
+
+ seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed
+
+ # Build the sampler.
+ if self.args.group_by_length:
+ if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset):
+ lengths = (
+ self.train_dataset[self.args.length_column_name]
+ if self.args.length_column_name in self.train_dataset.column_names
+ else None
+ )
+ else:
+ lengths = None
+ model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
+ if self.args.world_size <= 1:
+ return LengthGroupedSampler(
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
+ dataset=self.train_dataset,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ generator=generator,
+ )
+ else:
+ return DistributedLengthGroupedSampler(
+ self.args.train_batch_size * self.args.gradient_accumulation_steps,
+ dataset=self.train_dataset,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ lengths=lengths,
+ model_input_name=model_input_name,
+ seed=seed,
+ )
+
+ else:
+ if self.args.world_size <= 1:
+ return RandomSampler(self.train_dataset, generator=generator)
+ elif (
+ self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL]
+ and not self.args.dataloader_drop_last
+ ):
+                # Use a loop for TPUs when drop_last is False so that all batches have the same size.
+ return DistributedSamplerWithLoop(
+ self.train_dataset,
+ batch_size=self.args.per_device_train_batch_size,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ seed=seed,
+ )
+ else:
+ return DistributedSampler(
+ self.train_dataset,
+ num_replicas=self.args.world_size,
+ rank=self.args.process_index,
+ seed=seed,
+ )
+
+ def get_train_dataloader(self) -> DataLoader:
+ """
+ Returns the training [`~torch.utils.data.DataLoader`].
+
+ Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
+ training if necessary) otherwise.
+
+ Subclass and override this method if you want to inject some custom behavior.
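+
+        A minimal sketch of the override route (illustrative; `MyTrainer` is a hypothetical subclass):
+
+        ```python
+        class MyTrainer(Trainer):
+            def get_train_dataloader(self):
+                dataloader = super().get_train_dataloader()
+                # inject custom behavior here, e.g. inspect or log the dataloader before returning it
+                return dataloader
+        ```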
+ """
+ if self.train_dataset is None:
+ raise ValueError("Trainer: training requires a train_dataset.")
+
+ train_dataset = self.train_dataset
+ data_collator = self.data_collator
+ if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
+ train_dataset = self._remove_unused_columns(train_dataset, description="training")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="training")
+
+ if isinstance(train_dataset, torch.utils.data.IterableDataset):
+ if self.args.world_size > 1:
+ train_dataset = IterableDatasetShard(
+ train_dataset,
+ batch_size=self._train_batch_size,
+ drop_last=self.args.dataloader_drop_last,
+ num_processes=self.args.world_size,
+ process_index=self.args.process_index,
+ )
+
+ return DataLoader(
+ train_dataset,
+ batch_size=self._train_batch_size,
+ collate_fn=data_collator,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ train_sampler = self._get_train_sampler()
+
+ return DataLoader(
+ train_dataset,
+ batch_size=self._train_batch_size,
+ sampler=train_sampler,
+ collate_fn=data_collator,
+ drop_last=self.args.dataloader_drop_last,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ worker_init_fn=seed_worker,
+ )
+
+ def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]:
+ # Deprecated code
+ if self.args.use_legacy_prediction_loop:
+ if is_torch_tpu_available():
+ return SequentialDistributedSampler(
+ eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal()
+ )
+ elif is_sagemaker_mp_enabled():
+ return SequentialDistributedSampler(
+ eval_dataset,
+ num_replicas=smp.dp_size(),
+ rank=smp.dp_rank(),
+ batch_size=self.args.per_device_eval_batch_size,
+ )
+ elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ return SequentialDistributedSampler(eval_dataset)
+ else:
+ return SequentialSampler(eval_dataset)
+
+ if self.args.world_size <= 1:
+ return SequentialSampler(eval_dataset)
+ else:
+ return ShardSampler(
+ eval_dataset,
+ batch_size=self.args.per_device_eval_batch_size,
+ num_processes=self.args.world_size,
+ process_index=self.args.process_index,
+ )
+
+ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
+ """
+ Returns the evaluation [`~torch.utils.data.DataLoader`].
+
+ Subclass and override this method if you want to inject some custom behavior.
+
+ Args:
+ eval_dataset (`torch.utils.data.Dataset`, *optional*):
+ If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted
+ by the `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ if eval_dataset is None and self.eval_dataset is None:
+ raise ValueError("Trainer: evaluation requires an eval_dataset.")
+ eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
+ data_collator = self.data_collator
+
+ if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
+ eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation")
+
+ if isinstance(eval_dataset, torch.utils.data.IterableDataset):
+ if self.args.world_size > 1:
+ eval_dataset = IterableDatasetShard(
+ eval_dataset,
+ batch_size=self.args.per_device_eval_batch_size,
+ drop_last=self.args.dataloader_drop_last,
+ num_processes=self.args.world_size,
+ process_index=self.args.process_index,
+ )
+ return DataLoader(
+ eval_dataset,
+ batch_size=self.args.eval_batch_size,
+ collate_fn=data_collator,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ eval_sampler = self._get_eval_sampler(eval_dataset)
+
+ return DataLoader(
+ eval_dataset,
+ sampler=eval_sampler,
+ batch_size=self.args.eval_batch_size,
+ collate_fn=data_collator,
+ drop_last=self.args.dataloader_drop_last,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
+ """
+ Returns the test [`~torch.utils.data.DataLoader`].
+
+ Subclass and override this method if you want to inject some custom behavior.
+
+ Args:
+            test_dataset (`torch.utils.data.Dataset`):
+ The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
+ `model.forward()` method are automatically removed. It must implement `__len__`.
+ """
+ data_collator = self.data_collator
+
+ if is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
+ test_dataset = self._remove_unused_columns(test_dataset, description="test")
+ else:
+ data_collator = self._get_collator_with_removed_columns(data_collator, description="test")
+
+ if isinstance(test_dataset, torch.utils.data.IterableDataset):
+ if self.args.world_size > 1:
+ test_dataset = IterableDatasetShard(
+ test_dataset,
+ batch_size=self.args.eval_batch_size,
+ drop_last=self.args.dataloader_drop_last,
+ num_processes=self.args.world_size,
+ process_index=self.args.process_index,
+ )
+ return DataLoader(
+ test_dataset,
+ batch_size=self.args.eval_batch_size,
+ collate_fn=data_collator,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ test_sampler = self._get_eval_sampler(test_dataset)
+
+ # We use the same batch_size as for eval.
+ return DataLoader(
+ test_dataset,
+ sampler=test_sampler,
+ batch_size=self.args.eval_batch_size,
+ collate_fn=data_collator,
+ drop_last=self.args.dataloader_drop_last,
+ num_workers=self.args.dataloader_num_workers,
+ pin_memory=self.args.dataloader_pin_memory,
+ )
+
+ def create_optimizer_and_scheduler(self, num_training_steps: int):
+ """
+ Setup the optimizer and the learning rate scheduler.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or
+        `create_scheduler`).
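+
+        A minimal sketch of the `optimizers` route mentioned above (illustrative; assumes `model`, `args` and
+        `train_dataset` already exist):
+
+        ```python
+        from torch.optim import AdamW
+        from transformers import get_scheduler
+
+        optimizer = AdamW(model.parameters(), lr=5e-5)
+        lr_scheduler = get_scheduler("linear", optimizer, num_warmup_steps=0, num_training_steps=1000)
+        trainer = Trainer(model=model, args=args, train_dataset=train_dataset, optimizers=(optimizer, lr_scheduler))
+        ```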
+ """
+ self.create_optimizer()
+ if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16:
+ # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer
+ optimizer = self.optimizer.optimizer
+ else:
+ optimizer = self.optimizer
+ self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer)
+
+ def create_optimizer(self):
+ """
+ Setup the optimizer.
+
+        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
+        Trainer's init through `optimizers`, or subclass and override this method.
+ """
+ opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+
+ if self.optimizer is None:
+ decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
+ decay_parameters = [name for name in decay_parameters if "bias" not in name]
+ optimizer_grouped_parameters = [
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": self.args.weight_decay,
+ },
+ {
+ "params": [
+ p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
+ ],
+ "weight_decay": 0.0,
+ },
+ ]
+
+ optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
+
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ self.optimizer = OSS(
+ params=optimizer_grouped_parameters,
+ optim=optimizer_cls,
+ **optimizer_kwargs,
+ )
+ else:
+ self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+ if optimizer_cls.__name__ == "Adam8bit":
+ import bitsandbytes
+
+ manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+ skipped = 0
+ for module in opt_model.modules():
+ if isinstance(module, nn.Embedding):
+ skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
+ logger.info(f"skipped {module}: {skipped/2**20}M params")
+ manager.register_module_override(module, "weight", {"optim_bits": 32})
+ logger.debug(f"bitsandbytes: will optimize {module} in fp32")
+ logger.info(f"skipped: {skipped/2**20}M params")
+
+ if is_sagemaker_mp_enabled():
+ self.optimizer = smp.DistributedOptimizer(self.optimizer)
+
+ return self.optimizer
+
+ @staticmethod
+ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
+ """
+ Returns the optimizer class and optimizer parameters based on the training arguments.
+
+ Args:
+ args (`transformers.training_args.TrainingArguments`):
+ The training arguments for the training session.
+
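+        Example (illustrative; assumes `training_args` and `model` already exist):
+
+        ```python
+        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args)
+        optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
+        ```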
+ """
+
+ # parse args.optim_args
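+        # e.g. args.optim_args = "momentum_dtype=float16,use_kahan_summation=True" parses to
+        # {"momentum_dtype": "float16", "use_kahan_summation": "True"}; values stay strings and are cast where consumed below.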
+ optim_args = {}
+ if args.optim_args:
+ for mapping in args.optim_args.replace(" ", "").split(","):
+ key, value = mapping.split("=")
+ optim_args[key] = value
+
+ optimizer_kwargs = {"lr": args.learning_rate}
+
+ adam_kwargs = {
+ "betas": (args.adam_beta1, args.adam_beta2),
+ "eps": args.adam_epsilon,
+ }
+ if args.optim == OptimizerNames.ADAFACTOR:
+ optimizer_cls = Adafactor
+ optimizer_kwargs.update({"scale_parameter": False, "relative_step": False})
+ elif args.optim == OptimizerNames.ADAMW_HF:
+ from transformers.optimization import AdamW
+
+ optimizer_cls = AdamW
+ optimizer_kwargs.update(adam_kwargs)
+ elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]:
+ from torch.optim import AdamW
+
+ optimizer_cls = AdamW
+ optimizer_kwargs.update(adam_kwargs)
+ if args.optim == OptimizerNames.ADAMW_TORCH_FUSED:
+ optimizer_kwargs.update({"fused": True})
+ elif args.optim == OptimizerNames.ADAMW_TORCH_XLA:
+ try:
+ from torch_xla.amp.syncfree import AdamW
+
+ optimizer_cls = AdamW
+ optimizer_kwargs.update(adam_kwargs)
+ except ImportError:
+ raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.")
+ elif args.optim == OptimizerNames.ADAMW_APEX_FUSED:
+ try:
+ from apex.optimizers import FusedAdam
+
+ optimizer_cls = FusedAdam
+ optimizer_kwargs.update(adam_kwargs)
+ except ImportError:
+ raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!")
+ elif args.optim in [
+ OptimizerNames.ADAMW_BNB,
+ OptimizerNames.ADAMW_8BIT,
+ OptimizerNames.PAGED_ADAMW,
+ OptimizerNames.PAGED_ADAMW_8BIT,
+ OptimizerNames.LION,
+ OptimizerNames.LION_8BIT,
+ OptimizerNames.PAGED_LION,
+ OptimizerNames.PAGED_LION_8BIT,
+ ]:
+ try:
+ from bitsandbytes.optim import AdamW, Lion
+
+ is_paged = False
+ optim_bits = 32
+ optimizer_cls = None
+ additional_optim_kwargs = adam_kwargs
+ if "paged" in args.optim:
+ is_paged = True
+ if "8bit" in args.optim:
+ optim_bits = 8
+ if "adam" in args.optim:
+ optimizer_cls = AdamW
+ elif "lion" in args.optim:
+ optimizer_cls = Lion
+ additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)}
+
+ bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits}
+ optimizer_kwargs.update(additional_optim_kwargs)
+ optimizer_kwargs.update(bnb_kwargs)
+ except ImportError:
+ raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!")
+ elif args.optim == OptimizerNames.ADAMW_BNB:
+ try:
+ from bitsandbytes.optim import Adam8bit
+
+ optimizer_cls = Adam8bit
+ optimizer_kwargs.update(adam_kwargs)
+ except ImportError:
+ raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!")
+ elif args.optim == OptimizerNames.ADAMW_ANYPRECISION:
+ try:
+ from torchdistx.optimizers import AnyPrecisionAdamW
+
+ optimizer_cls = AnyPrecisionAdamW
+ optimizer_kwargs.update(adam_kwargs)
+
+ # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx.
+ optimizer_kwargs.update(
+ {
+ "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")),
+ "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")),
+ "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")),
+ "compensation_buffer_dtype": getattr(
+ torch, optim_args.get("compensation_buffer_dtype", "bfloat16")
+ ),
+ }
+ )
+ except ImportError:
+ raise ValueError("Please install https://github.com/pytorch/torchdistx")
+ elif args.optim == OptimizerNames.SGD:
+ optimizer_cls = torch.optim.SGD
+ elif args.optim == OptimizerNames.ADAGRAD:
+ optimizer_cls = torch.optim.Adagrad
+ else:
+ raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
+ return optimizer_cls, optimizer_kwargs
+
+ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
+ """
+ Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
+ passed as an argument.
+
+ Args:
+ num_training_steps (int): The number of training steps to do.
+ """
+ if self.lr_scheduler is None:
+ self.lr_scheduler = get_scheduler(
+ self.args.lr_scheduler_type,
+ optimizer=self.optimizer if optimizer is None else optimizer,
+ num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
+ num_training_steps=num_training_steps,
+ )
+ return self.lr_scheduler
+
+ def num_examples(self, dataloader: DataLoader) -> int:
+ """
+        Helper to get the number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When
+        dataloader.dataset does not exist or has no length, estimates as best it can.
+ """
+ try:
+ dataset = dataloader.dataset
+ # Special case for IterableDatasetShard, we need to dig deeper
+ if isinstance(dataset, IterableDatasetShard):
+ return len(dataloader.dataset.dataset)
+ return len(dataloader.dataset)
+ except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader
+ return len(dataloader) * self.args.per_device_train_batch_size
+
+ def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
+ """HP search setup code"""
+ self._trial = trial
+
+ if self.hp_search_backend is None or trial is None:
+ return
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ params = self.hp_space(trial)
+ elif self.hp_search_backend == HPSearchBackend.RAY:
+ params = trial
+ params.pop("wandb", None)
+ elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+ params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()}
+ elif self.hp_search_backend == HPSearchBackend.WANDB:
+ params = trial
+
+ for key, value in params.items():
+ if not hasattr(self.args, key):
+ logger.warning(
+ f"Trying to set {key} in the hyperparameter search but there is no corresponding field in"
+ " `TrainingArguments`."
+ )
+ continue
+ old_attr = getattr(self.args, key, None)
+ # Casting value to the proper type
+ if old_attr is not None:
+ value = type(old_attr)(value)
+ setattr(self.args, key, value)
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ logger.info(f"Trial: {trial.params}")
+ if self.hp_search_backend == HPSearchBackend.SIGOPT:
+ logger.info(f"SigOpt Assignments: {trial.assignments}")
+ if self.hp_search_backend == HPSearchBackend.WANDB:
+ logger.info(f"W&B Sweep parameters: {trial}")
+ if self.is_deepspeed_enabled:
+ if self.args.deepspeed is None:
+ raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set")
+ # Rebuild the deepspeed config to reflect the updated training parameters
+ from accelerate.utils import DeepSpeedPlugin
+
+ from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+ self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed)
+ self.args.hf_deepspeed_config.trainer_config_process(self.args)
+ self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config)
+ self.create_accelerator_and_postprocess()
+
+ def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]):
+ if self.hp_search_backend is None or trial is None:
+ return
+ self.objective = self.compute_objective(metrics.copy())
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ import optuna
+
+ trial.report(self.objective, step)
+ if trial.should_prune():
+ self.callback_handler.on_train_end(self.args, self.state, self.control)
+ raise optuna.TrialPruned()
+ elif self.hp_search_backend == HPSearchBackend.RAY:
+ from ray import tune
+
+ if self.control.should_save:
+ self._tune_save_checkpoint()
+ tune.report(objective=self.objective, **metrics)
+
+ def _tune_save_checkpoint(self):
+ from ray import tune
+
+ if not self.use_tune_checkpoints:
+ return
+ with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
+ output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
+ self.save_model(output_dir, _internal_call=True)
+ if self.args.should_save:
+ self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+ torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+
+ def call_model_init(self, trial=None):
+ model_init_argcount = number_of_arguments(self.model_init)
+ if model_init_argcount == 0:
+ model = self.model_init()
+ elif model_init_argcount == 1:
+ model = self.model_init(trial)
+ else:
+ raise RuntimeError("model_init should have 0 or 1 argument.")
+
+ if model is None:
+ raise RuntimeError("model_init should not return None.")
+
+ return model
+
+ def torch_jit_model_eval(self, model, dataloader, training=False):
+ if not training:
+ if dataloader is None:
+                logger.warning("Failed to use PyTorch jit mode because the current dataloader is None.")
+ return model
+ example_batch = next(iter(dataloader))
+ example_batch = self._prepare_inputs(example_batch)
+ try:
+ jit_model = model.eval()
+ with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]):
+ if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"):
+ if isinstance(example_batch, dict):
+ jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
+ else:
+ jit_model = torch.jit.trace(
+ jit_model,
+ example_kwarg_inputs={key: example_batch[key] for key in example_batch},
+ strict=False,
+ )
+ else:
+ jit_inputs = []
+ for key in example_batch:
+ example_tensor = torch.ones_like(example_batch[key])
+ jit_inputs.append(example_tensor)
+ jit_inputs = tuple(jit_inputs)
+ jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False)
+ jit_model = torch.jit.freeze(jit_model)
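+                # Warm-up passes: the first invocations let the JIT finish optimizing the frozen graph before real use.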
+ with torch.no_grad():
+ jit_model(**example_batch)
+ jit_model(**example_batch)
+ model = jit_model
+ self.use_cpu_amp = False
+ self.use_cuda_amp = False
+ except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e:
+ logger.warning(f"failed to use PyTorch jit mode due to: {e}.")
+
+ return model
+
+ def ipex_optimize_model(self, model, training=False, dtype=torch.float32):
+ if not is_ipex_available():
+ raise ImportError(
+ "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer"
+ " to https://github.com/intel/intel-extension-for-pytorch."
+ )
+
+ import intel_extension_for_pytorch as ipex
+
+ if not training:
+ model.eval()
+ dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype
+ # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings
+ model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train)
+ else:
+ if not model.training:
+ model.train()
+ model, self.optimizer = ipex.optimize(
+ model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1"
+ )
+
+ return model
+
+ def _wrap_model(self, model, training=True, dataloader=None):
+ if self.args.use_ipex:
+ dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32
+ model = self.ipex_optimize_model(model, training, dtype=dtype)
+
+ if is_sagemaker_mp_enabled():
+ # Wrapping the base model twice in a DistributedModel will raise an error.
+ if isinstance(self.model_wrapped, smp.model.DistributedModel):
+ return self.model_wrapped
+ return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps)
+
+        # train/eval could be run multiple times - if already wrapped, don't re-wrap it again
+ if unwrap_model(model) is not model:
+ return model
+
+ # Mixed precision training with apex (torch < 1.6)
+ if self.use_apex and training:
+ model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)
+
+        # Multi-gpu training (should be after apex fp16 initialization) / 8-bit models do not support DDP
+ if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False):
+ model = nn.DataParallel(model)
+
+ if self.args.jit_mode_eval:
+ start_time = time.time()
+ model = self.torch_jit_model_eval(model, dataloader, training)
+ self.jit_compilation_time = round(time.time() - start_time, 4)
+
+ # Note: in torch.distributed mode, there's no point in wrapping the model
+        # inside a DistributedDataParallel as we'll be under `no_grad` anyway.
+ if not training:
+ return model
+
+ # Distributed training (should be after apex fp16 initialization)
+ if self.sharded_ddp is not None:
+ # Sharded DDP!
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ model = ShardedDDP(model, self.optimizer)
+ else:
+ mixed_precision = self.args.fp16 or self.args.bf16
+ cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
+ zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
+ # XXX: Breaking the self.model convention but I see no way around it for now.
+ if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp:
+ model = auto_wrap(model)
+ self.model = model = FullyShardedDDP(
+ model,
+ mixed_precision=mixed_precision,
+ reshard_after_forward=zero_3,
+ cpu_offload=cpu_offload,
+ ).to(self.args.device)
+ # Distributed training using PyTorch FSDP
+ elif self.fsdp is not None and self.args.fsdp_config["xla"]:
+ try:
+ from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
+ from torch_xla.distributed.fsdp import checkpoint_module
+ from torch_xla.distributed.fsdp.wrap import (
+ size_based_auto_wrap_policy,
+ transformer_auto_wrap_policy,
+ )
+ except ImportError:
+ raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
+ auto_wrap_policy = None
+ auto_wrapper_callable = None
+ if self.args.fsdp_config["fsdp_min_num_params"] > 0:
+ auto_wrap_policy = functools.partial(
+ size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
+ )
+ elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
+ transformer_cls_to_wrap = set()
+ for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
+ transformer_cls = get_module_class_from_name(model, layer_class)
+ if transformer_cls is None:
+ raise Exception("Could not find the transformer layer class to wrap in the model.")
+ else:
+ transformer_cls_to_wrap.add(transformer_cls)
+ auto_wrap_policy = functools.partial(
+ transformer_auto_wrap_policy,
+ # Transformer layer class to wrap
+ transformer_layer_cls=transformer_cls_to_wrap,
+ )
+ fsdp_kwargs = self.args.xla_fsdp_config
+ if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
+ # Apply gradient checkpointing to auto-wrapped sub-modules if specified
+ def auto_wrapper_callable(m, *args, **kwargs):
+ return FSDP(checkpoint_module(m), *args, **kwargs)
+
+ # Wrap the base model with an outer FSDP wrapper
+ self.model = model = FSDP(
+ model,
+ auto_wrap_policy=auto_wrap_policy,
+ auto_wrapper_callable=auto_wrapper_callable,
+ **fsdp_kwargs,
+ )
+
+            # Patch `xm.optimizer_step` so it does not reduce gradients in this case,
+            # as FSDP does not need gradient reduction over sharded parameters.
+ def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
+ loss = optimizer.step(**optimizer_args)
+ if barrier:
+ xm.mark_step()
+ return loss
+
+ xm.optimizer_step = patched_optimizer_step
+ elif is_sagemaker_dp_enabled():
+ model = nn.parallel.DistributedDataParallel(
+ model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
+ )
+ elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ if is_torch_neuroncore_available():
+ return model
+ kwargs = {}
+ if self.args.ddp_find_unused_parameters is not None:
+ kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters
+ elif isinstance(model, PreTrainedModel):
+ # find_unused_parameters breaks checkpointing as per
+ # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
+ kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing
+ else:
+ kwargs["find_unused_parameters"] = True
+
+ if self.args.ddp_bucket_cap_mb is not None:
+ kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb
+
+ self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs)
+
+ return model
+
+ def train(
+ self,
+ resume_from_checkpoint: Optional[Union[str, bool]] = None,
+ trial: Union["optuna.Trial", Dict[str, Any]] = None,
+ ignore_keys_for_eval: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ """
+ Main training entry point.
+
+ Args:
+ resume_from_checkpoint (`str` or `bool`, *optional*):
+ If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a
+ `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance
+ of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here.
+ trial (`optuna.Trial` or `Dict[str, Any]`, *optional*):
+ The trial run or the hyperparameter dictionary for hyperparameter search.
+            ignore_keys_for_eval (`List[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions for evaluation during the training.
+ kwargs:
+ Additional keyword arguments used to hide deprecated arguments
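+
+        Example (illustrative; assumes a configured `trainer`):
+
+        ```python
+        trainer.train()                             # train from scratch (or from the model passed at init)
+        trainer.train(resume_from_checkpoint=True)  # resume from the last checkpoint in `args.output_dir`
+        trainer.train(resume_from_checkpoint="output/checkpoint-500")  # resume from a specific checkpoint path
+        ```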
+ """
+ if resume_from_checkpoint is False:
+ resume_from_checkpoint = None
+
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ args = self.args
+
+ self.is_in_train = True
+
+ # do_train is not a reliable argument, as it might not be set and .train() still called, so
+ # the following is a workaround:
+ if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
+ self._move_model_to_device(self.model, args.device)
+
+ if "model_path" in kwargs:
+ resume_from_checkpoint = kwargs.pop("model_path")
+ warnings.warn(
+ "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
+ "instead.",
+ FutureWarning,
+ )
+ if len(kwargs) > 0:
+            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
+ # This might change the seed so needs to run first.
+ self._hp_search_setup(trial)
+ self._train_batch_size = self.args.train_batch_size
+
+ # Model re-init
+ model_reloaded = False
+ if self.model_init is not None:
+ # Seed must be set before instantiating the model when using model_init.
+ enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed)
+ self.model = self.call_model_init(trial)
+ model_reloaded = True
+ # Reinitializes optimizer and scheduler
+ self.optimizer, self.lr_scheduler = None, None
+
+ # Load potential model checkpoint
+ if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+ resume_from_checkpoint = get_last_checkpoint(args.output_dir)
+ if resume_from_checkpoint is None:
+ raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
+
+ if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
+ self._load_from_checkpoint(resume_from_checkpoint)
+
+ # If model was re-initialized, put it on the right device and update self.model_wrapped
+ if model_reloaded:
+ if self.place_model_on_device:
+ self._move_model_to_device(self.model, args.device)
+ self.model_wrapped = self.model
+
+ inner_training_loop = find_executable_batch_size(
+ self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size
+ )
+
+ return inner_training_loop(
+ args=args,
+ resume_from_checkpoint=resume_from_checkpoint,
+ trial=trial,
+ ignore_keys_for_eval=ignore_keys_for_eval,
+ )
+
+ def _inner_training_loop(
+ self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
+ ):
+ self.accelerator.free_memory()
+ self._train_batch_size = batch_size
+ logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
+ # Data loader and number of training steps
+ train_dataloader = self.get_train_dataloader()
+
+ # Setting up training control variables:
+ # number of training epochs: num_train_epochs
+ # number of training steps per epoch: num_update_steps_per_epoch
+ # total number of training steps to execute: max_steps
+ total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
+
+ len_dataloader = None
+ if has_length(train_dataloader):
+ len_dataloader = len(train_dataloader)
+ num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+ num_examples = self.num_examples(train_dataloader)
+ if args.max_steps > 0:
+ max_steps = args.max_steps
+ num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+ args.max_steps % num_update_steps_per_epoch > 0
+ )
+ # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+ # the best we can do.
+ num_train_samples = args.max_steps * total_train_batch_size
+ else:
+ max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+ num_train_epochs = math.ceil(args.num_train_epochs)
+ num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
+ elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size
+ max_steps = args.max_steps
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+ num_train_epochs = sys.maxsize
+ num_update_steps_per_epoch = max_steps
+ num_examples = total_train_batch_size * args.max_steps
+ num_train_samples = args.max_steps * total_train_batch_size
+ else:
+ raise ValueError(
+ "args.max_steps must be set to a positive value if dataloader does not have a length, was"
+ f" {args.max_steps}"
+ )
+
+ # Compute absolute values for logging, eval, and save if given as ratio
+ if args.logging_steps and args.logging_steps < 1:
+ args.logging_steps = math.ceil(max_steps * args.logging_steps)
+ if args.eval_steps and args.eval_steps < 1:
+ args.eval_steps = math.ceil(max_steps * args.eval_steps)
+ if args.save_steps and args.save_steps < 1:
+ args.save_steps = math.ceil(max_steps * args.save_steps)
+
+ if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+ if self.args.n_gpu > 1:
+                # nn.DataParallel(model) replicates the model, creating new variables, and module
+                # references registered here no longer work on the other GPUs, breaking the module.
+ raise ValueError(
+ "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
+ " (torch.distributed.launch)."
+ )
+ else:
+ debug_overflow = DebugUnderflowOverflow(self.model) # noqa
+
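+        # For sharded DDP (non-simple), SageMaker MP and FSDP, optimizer creation is delayed until after the model is
+        # wrapped, so the optimizer is built over the parameters the wrapped model actually exposes.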
+ delay_optimizer_creation = (
+ self.sharded_ddp is not None
+ and self.sharded_ddp != ShardedDDPOption.SIMPLE
+ or is_sagemaker_mp_enabled()
+ or self.fsdp is not None
+ )
+
+ if self.is_deepspeed_enabled:
+ self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)
+
+ if not delay_optimizer_creation:
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ self.state = TrainerState()
+ self.state.is_hyper_param_search = trial is not None
+
+ # Activate gradient checkpointing if needed
+ if args.gradient_checkpointing:
+ self.model.gradient_checkpointing_enable()
+
+ model = self._wrap_model(self.model_wrapped)
+
+ if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
+ self._load_from_checkpoint(resume_from_checkpoint, model)
+
+ # as the model is wrapped, don't use `accelerator.prepare`
+ # this is for unhandled cases such as
+ # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
+        use_accelerator_prepare = model is self.model
+
+ if delay_optimizer_creation:
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ # prepare using `accelerator` prepare
+ if use_accelerator_prepare:
+ if hasattr(self.lr_scheduler, "step"):
+ if self.use_apex:
+ model = self.accelerator.prepare(self.model)
+ else:
+ model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+ else:
+                # to handle cases where we pass a "DummyScheduler", e.g. when it is specified in the DeepSpeed config.
+ model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
+ self.model, self.optimizer, self.lr_scheduler
+ )
+
+ if self.is_fsdp_enabled:
+ self.model = model
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if model is not self.model:
+ self.model_wrapped = model
+
+ # backward compatibility
+ if self.is_deepspeed_enabled:
+ self.deepspeed = self.model_wrapped
+
+ # deepspeed ckpt loading
+ if resume_from_checkpoint is not None and self.is_deepspeed_enabled:
+ deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint)
+
+ # Check if saved optimizer or scheduler states exist
+ self._load_optimizer_and_scheduler(resume_from_checkpoint)
+
+ # important: at this point:
+ # self.model is the Transformers Model
+ # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
+
+ # Train!
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {num_examples:,}")
+ logger.info(f" Num Epochs = {num_train_epochs:,}")
+ logger.info(f" Instantaneous batch size per device = {self._train_batch_size:,}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_steps:,}")
+ logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")
+
+ self.state.epoch = 0
+ start_time = time.time()
+ epochs_trained = 0
+ steps_trained_in_current_epoch = 0
+ steps_trained_progress_bar = None
+
+ # Check if continuing training from a checkpoint
+ if resume_from_checkpoint is not None and os.path.isfile(
+ os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+ ):
+ self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+ epochs_trained = self.state.global_step // num_update_steps_per_epoch
+ if not args.ignore_data_skip:
+ steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+ steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+ else:
+ steps_trained_in_current_epoch = 0
+
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
+ logger.info(f" Continuing training from epoch {epochs_trained}")
+ logger.info(f" Continuing training from global step {self.state.global_step}")
+ if not args.ignore_data_skip:
+ if skip_first_batches is None:
+ logger.info(
+ f" Will skip the first {epochs_trained} epochs then the first"
+ f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time,"
+                        " you can install the latest version of Accelerate with `pip install -U accelerate`. You can"
+ " also add the `--ignore_data_skip` flag to your launch command, but you will resume the"
+ " training on data already seen by your model."
+ )
+ else:
+ logger.info(
+ f" Will skip the first {epochs_trained} epochs then the first"
+ f" {steps_trained_in_current_epoch} batches in the first epoch."
+ )
+ if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None:
+ steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
+ steps_trained_progress_bar.set_description("Skipping the first batches")
+
+ # Update the references
+ self.callback_handler.model = self.model
+ self.callback_handler.optimizer = self.optimizer
+ self.callback_handler.lr_scheduler = self.lr_scheduler
+ self.callback_handler.train_dataloader = train_dataloader
+ if self.hp_name is not None and self._trial is not None:
+            # use self._trial because the SigOpt/Optuna HPO only calls `_hp_search_setup(trial)` instead of passing the
+            # trial parameter to `train` when using DDP.
+ self.state.trial_name = self.hp_name(self._trial)
+ if trial is not None:
+ assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
+ self.state.trial_params = hp_params(assignments)
+ else:
+ self.state.trial_params = None
+ # This should be the same if the state has been saved but in case the training arguments changed, it's safer
+ # to set this after the load.
+ self.state.max_steps = max_steps
+ self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
+
+ # tr_loss is a tensor to avoid synchronization of TPUs through .item()
+ tr_loss = torch.tensor(0.0).to(args.device)
+        # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
+ self._total_loss_scalar = 0.0
+ self._globalstep_last_logged = self.state.global_step
+ model.zero_grad()
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
+ if not args.ignore_data_skip:
+ for epoch in range(epochs_trained):
+ is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance(
+ train_dataloader.sampler, RandomSampler
+ )
+ if is_torch_less_than_1_11 or not is_random_sampler:
+ # We just need to begin an iteration to create the randomization of the sampler.
+ # That was before PyTorch 1.11 however...
+ for _ in train_dataloader:
+ break
+ else:
+                    # Otherwise we need to iterate over the whole sampler because a random operation is added
+                    # at the very end.
+ _ = list(train_dataloader.sampler)
+
+ total_batched_samples = 0
+ for epoch in range(epochs_trained, num_train_epochs):
+ if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
+ train_dataloader.sampler.set_epoch(epoch)
+ elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard):
+ train_dataloader.dataset.set_epoch(epoch)
+
+ if is_torch_tpu_available():
+ parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
+ epoch_iterator = parallel_loader
+ else:
+ epoch_iterator = train_dataloader
+
+ # Reset the past mems state at the beginning of each epoch if necessary.
+ if args.past_index >= 0:
+ self._past = None
+
+ steps_in_epoch = (
+ len(epoch_iterator)
+ if len_dataloader is not None
+ else args.max_steps * args.gradient_accumulation_steps
+ )
+ self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+ if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+
+ rng_to_sync = False
+ steps_skipped = 0
+ if skip_first_batches is not None and steps_trained_in_current_epoch > 0:
+ epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
+ steps_skipped = steps_trained_in_current_epoch
+ steps_trained_in_current_epoch = 0
+ rng_to_sync = True
+
+ step = -1
+ for step, inputs in enumerate(epoch_iterator):
+ total_batched_samples += 1
+ if rng_to_sync:
+ self._load_rng_state(resume_from_checkpoint)
+ rng_to_sync = False
+
+ # Skip past any already trained steps if resuming training
+ if steps_trained_in_current_epoch > 0:
+ steps_trained_in_current_epoch -= 1
+ if steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.update(1)
+ if steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+ continue
+ elif steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.close()
+ steps_trained_progress_bar = None
+
+ if step % args.gradient_accumulation_steps == 0:
+ self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+ with self.accelerator.accumulate(model):
+ tr_loss_step = self.training_step(model, inputs)
+
+ if (
+ args.logging_nan_inf_filter
+ and not is_torch_tpu_available()
+ and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
+ ):
+ # if loss is nan or inf simply add the average of previous logged losses
+ tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+ else:
+ tr_loss += tr_loss_step
+
+ self.current_flos += float(self.floating_point_ops(inputs))
+
+ # should this be under the accumulate context manager?
+ # the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered
+ # in accelerate
+ if total_batched_samples % args.gradient_accumulation_steps == 0 or (
+ # last step in epoch but step is always smaller than gradient_accumulation_steps
+ steps_in_epoch <= args.gradient_accumulation_steps
+ and (step + 1) == steps_in_epoch
+ ):
+ # Gradient clipping
+ if args.max_grad_norm is not None and args.max_grad_norm > 0:
+ # deepspeed does its own clipping
+
+ if self.do_grad_scaling:
+ # Reduce gradients first for XLA
+ if is_torch_tpu_available():
+ gradients = xm._fetch_gradients(self.optimizer)
+ xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size())
+ # AMP: gradients need unscaling
+ self.scaler.unscale_(self.optimizer)
+
+ if is_sagemaker_mp_enabled() and args.fp16:
+ self.optimizer.clip_master_grads(args.max_grad_norm)
+ elif hasattr(self.optimizer, "clip_grad_norm"):
+ # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
+ self.optimizer.clip_grad_norm(args.max_grad_norm)
+ elif hasattr(model, "clip_grad_norm_"):
+ # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
+ model.clip_grad_norm_(args.max_grad_norm)
+ elif self.use_apex:
+ # Revert to normal clipping otherwise, handling Apex or full precision
+ nn.utils.clip_grad_norm_(
+ amp.master_params(self.optimizer),
+ args.max_grad_norm,
+ )
+ else:
+ self.accelerator.clip_grad_norm_(
+ model.parameters(),
+ args.max_grad_norm,
+ )
+
+ # Optimizer step
+ optimizer_was_run = True
+ if is_torch_tpu_available():
+ if self.do_grad_scaling:
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ else:
+ xm.optimizer_step(self.optimizer)
+ elif self.do_grad_scaling:
+ scale_before = self.scaler.get_scale()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ scale_after = self.scaler.get_scale()
+ optimizer_was_run = scale_before <= scale_after
+ else:
+ self.optimizer.step()
+ optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
+
+ if optimizer_was_run:
+ # Delay optimizer scheduling until metrics are generated
+ if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ self.lr_scheduler.step()
+
+ model.zero_grad()
+ self.state.global_step += 1
+ self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+ self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+ else:
+ self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ break
+ if step < 0:
+ logger.warning(
+                    "There does not seem to be a single sample in your epoch_iterator; stopping training at step"
+ f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+ f" num_steps ({max_steps}) higher than the number of available samples."
+ )
+ self.control.should_training_stop = True
+
+ self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+ self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+
+ if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+ if is_torch_tpu_available():
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+ else:
+ logger.warning(
+ "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
+ "configured. Check your training configuration if this is unexpected."
+ )
+ if self.control.should_training_stop:
+ break
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of training
+ delattr(self, "_past")
+
+ logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+ if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+            # Wait for everyone to get here so we are sure the model has been saved by process 0.
+ if is_torch_tpu_available():
+ xm.rendezvous("load_best_model_at_end")
+ elif args.parallel_mode == ParallelMode.DISTRIBUTED:
+ dist.barrier()
+ elif is_sagemaker_mp_enabled():
+ smp.barrier()
+
+ self._load_best_model()
+
+ # add remaining tr_loss
+ self._total_loss_scalar += tr_loss.item()
+ train_loss = self._total_loss_scalar / self.state.global_step
+
+ metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
+ self.store_flos()
+ metrics["total_flos"] = self.state.total_flos
+ metrics["train_loss"] = train_loss
+
+ self.is_in_train = False
+
+ self._memory_tracker.stop_and_update_metrics(metrics)
+
+ self.log(metrics)
+
+ run_dir = self._get_output_dir(trial)
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)
+
+        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and the process is allowed to save.
+ if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
+ for checkpoint in checkpoints_sorted:
+ if checkpoint != self.state.best_model_checkpoint:
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint)
+
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+ return TrainOutput(self.state.global_step, train_loss, metrics)
+
+ def _get_output_dir(self, trial):
+ if self.hp_search_backend is not None and trial is not None:
+ if self.hp_search_backend == HPSearchBackend.OPTUNA:
+ run_id = trial.number
+ elif self.hp_search_backend == HPSearchBackend.RAY:
+ from ray import tune
+
+ run_id = tune.get_trial_id()
+ elif self.hp_search_backend == HPSearchBackend.SIGOPT:
+ run_id = trial.id
+ elif self.hp_search_backend == HPSearchBackend.WANDB:
+ import wandb
+
+ run_id = wandb.run.id
+ run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}"
+ run_dir = os.path.join(self.args.output_dir, run_name)
+ else:
+ run_dir = self.args.output_dir
+ return run_dir
+
+ def _load_from_checkpoint(self, resume_from_checkpoint, model=None):
+ if model is None:
+ model = self.model
+
+ config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
+
+ weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
+ weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+ safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
+ safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
+
+ if not any(
+ os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file]
+ ):
+ raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+
+ logger.info(f"Loading model from {resume_from_checkpoint}.")
+
+ if os.path.isfile(config_file):
+ config = PretrainedConfig.from_json_file(config_file)
+ checkpoint_version = config.transformers_version
+ if checkpoint_version is not None and checkpoint_version != __version__:
+ logger.warning(
+ f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+ f"Transformers but your current version is {__version__}. This is not recommended and could "
+                    "lead to errors or unwanted behaviors."
+ )
+
+ if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
+ # If the model is on the GPU, it still works!
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
+ # If the 'user_content.pt' file exists, load with the new smp api.
+ # Checkpoint must have been saved with the new smp api.
+ smp.resume_from_checkpoint(
+ path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False
+ )
+ else:
+ # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+ # Checkpoint must have been saved with the old smp api.
+ if hasattr(self.args, "fp16") and self.args.fp16 is True:
+ logger.warning(
+                            "Enabling FP16 and loading from smp < 1.10 checkpoint together is not supported."
+ )
+ state_dict = torch.load(weights_file, map_location="cpu")
+ # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
+ state_dict["_smp_is_partial"] = False
+ load_result = model.load_state_dict(state_dict, strict=True)
+ # release memory
+ del state_dict
+ elif self.is_fsdp_enabled:
+ self.accelerator.state.fsdp_plugin.load_model(self.accelerator, model, resume_from_checkpoint)
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ if self.args.save_safetensors and os.path.isfile(safe_weights_file):
+ state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
+ else:
+ state_dict = torch.load(weights_file, map_location="cpu")
+
+ # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+ # which takes *args instead of **kwargs
+ load_result = model.load_state_dict(state_dict, False)
+ # release memory
+ del state_dict
+ self._issue_warnings_after_load(load_result)
+ else:
+ # We load the sharded checkpoint
+ load_result = load_sharded_checkpoint(
+ model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors
+ )
+ if not is_sagemaker_mp_enabled():
+ self._issue_warnings_after_load(load_result)
+
+ def _load_best_model(self):
+ logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
+ best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+ best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME)
+ best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME)
+ best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+
+ model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+ if (
+ os.path.exists(best_model_path)
+ or os.path.exists(best_safe_model_path)
+ or os.path.exists(best_adapter_model_path)
+ or os.path.exists(best_safe_adapter_model_path)
+ ):
+ if self.is_deepspeed_enabled:
+ deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint)
+ else:
+ has_been_loaded = True
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")):
+ # If the 'user_content.pt' file exists, load with the new smp api.
+ # Checkpoint must have been saved with the new smp api.
+ smp.resume_from_checkpoint(
+ path=self.state.best_model_checkpoint,
+ tag=WEIGHTS_NAME,
+ partial=False,
+ load_optimizer=False,
+ )
+ else:
+ # If the 'user_content.pt' file does NOT exist, load with the old smp api.
+ # Checkpoint must have been saved with the old smp api.
+ if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+ state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+ else:
+ state_dict = torch.load(best_model_path, map_location="cpu")
+
+ state_dict["_smp_is_partial"] = False
+ load_result = model.load_state_dict(state_dict, strict=True)
+ elif self.is_fsdp_enabled:
+ self.accelerator.state.fsdp_plugin.load_model(
+ self.accelerator, model, self.state.best_model_checkpoint
+ )
+ else:
+ if is_peft_available() and isinstance(model, PeftModel):
+ # If the model was trained with PEFT & LoRA, assume that the adapter has been saved properly.
+ if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
+ if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
+ model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
+ # load_adapter has no return value, so build an empty _IncompatibleKeys result; adjust this if that changes.
+ from torch.nn.modules.module import _IncompatibleKeys
+
+ load_result = _IncompatibleKeys([], [])
+ else:
+ logger.warning(
+ "The intermediate checkpoints of PEFT may not be saved correctly, "
+ f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, "
+ "here are some examples https://github.com/huggingface/peft/issues/96"
+ )
+ has_been_loaded = False
+ else:
+ logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+ has_been_loaded = False
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+ state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+ else:
+ state_dict = torch.load(best_model_path, map_location="cpu")
+
+ # If the model is on the GPU, it still works!
+ # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
+ # which takes *args instead of **kwargs
+ load_result = model.load_state_dict(state_dict, False)
+ if not is_sagemaker_mp_enabled() and has_been_loaded:
+ self._issue_warnings_after_load(load_result)
+ elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)):
+ load_result = load_sharded_checkpoint(
+ model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled()
+ )
+ if not is_sagemaker_mp_enabled():
+ self._issue_warnings_after_load(load_result)
+ else:
+ logger.warning(
+ f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+ "on multiple nodes, you should activate `--save_on_each_node`."
+ )
+
+ def _issue_warnings_after_load(self, load_result):
+ if len(load_result.missing_keys) != 0:
+ if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set(
+ self.model._keys_to_ignore_on_save
+ ):
+ self.model.tie_weights()
+ else:
+ logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.")
+ if len(load_result.unexpected_keys) != 0:
+ logger.warning(
+ f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}."
+ )
+
+ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+ if self.control.should_log:
+ if is_torch_tpu_available():
+ xm.mark_step()
+
+ logs: Dict[str, float] = {}
+
+ # all_gather + mean() to get average loss over all processes
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
+ # reset tr_loss to zero
+ tr_loss -= tr_loss
+
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+ logs["learning_rate"] = self._get_learning_rate()
+
+ self._total_loss_scalar += tr_loss_scalar
+ self._globalstep_last_logged = self.state.global_step
+ self.store_flos()
+
+ self.log(logs)
+
+ metrics = None
+ if self.control.should_evaluate:
+ if isinstance(self.eval_dataset, dict):
+ metrics = {}
+ for eval_dataset_name, eval_dataset in self.eval_dataset.items():
+ dataset_metrics = self.evaluate(
+ eval_dataset=eval_dataset,
+ ignore_keys=ignore_keys_for_eval,
+ metric_key_prefix=f"eval_{eval_dataset_name}",
+ )
+ metrics.update(dataset_metrics)
+ else:
+ metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+ self._report_to_hp_search(trial, self.state.global_step, metrics)
+
+ # Run delayed LR scheduler now that metrics are populated
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ metric_to_check = self.args.metric_for_best_model
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+ self.lr_scheduler.step(metrics[metric_to_check])
+
+ if self.control.should_save:
+ self._save_checkpoint(model, trial, metrics=metrics)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ def _load_rng_state(self, checkpoint):
+ # Load RNG states from `checkpoint`
+ if checkpoint is None:
+ return
+
+ if self.args.world_size > 1:
+ process_index = self.args.process_index
+ rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+ if not os.path.isfile(rng_file):
+ logger.info(
+ f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
+ "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+ )
+ return
+ else:
+ rng_file = os.path.join(checkpoint, "rng_state.pth")
+ if not os.path.isfile(rng_file):
+ logger.info(
+ "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+ "fashion, reproducibility is not guaranteed."
+ )
+ return
+
+ checkpoint_rng_state = torch.load(rng_file)
+ random.setstate(checkpoint_rng_state["python"])
+ np.random.set_state(checkpoint_rng_state["numpy"])
+ torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+ if torch.cuda.is_available():
+ if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"])
+ else:
+ try:
+ torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"])
+ except Exception as e:
+ logger.info(
+ f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}"
+ "\nThis won't yield the same results as if the training had not been interrupted."
+ )
+ if is_torch_tpu_available():
+ xm.set_rng_state(checkpoint_rng_state["xla"])
+
+ def _save_checkpoint(self, model, trial, metrics=None):
+ # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
+ # want to save except FullyShardedDDP.
+ # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model"
+
+ # Save model checkpoint
+ checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}" # changed by homeway, 20230711
+
+ if self.hp_search_backend is None and trial is None:
+ self.store_flos()
+
+ run_dir = self._get_output_dir(trial=trial)
+ output_dir = os.path.join(run_dir, checkpoint_folder)
+ self.save_model(output_dir, _internal_call=True)
+ if self.is_deepspeed_enabled:
+ # under ZeRO-3 the model file itself doesn't get saved since it's bogus, unless the deepspeed
+ # config `stage3_gather_16bit_weights_on_model_save` is True
+ self.model_wrapped.save_checkpoint(output_dir)
+
+ # Save optimizer and scheduler
+ if self.sharded_ddp == ShardedDDPOption.SIMPLE:
+ self.optimizer.consolidate_state_dict()
+
+ if self.fsdp:
+ # FSDP has a different interface for saving optimizer states.
+ # Needs to be called on all ranks to gather all states.
+ # full_optim_state_dict will be deprecated after PyTorch 2.2!
+ full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
+
+ if is_torch_tpu_available():
+ xm.rendezvous("saving_optimizer_states")
+ xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+ elif is_sagemaker_mp_enabled():
+ opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False)
+ smp.barrier()
+ if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state:
+ smp.save(
+ opt_state_dict,
+ os.path.join(output_dir, OPTIMIZER_NAME),
+ partial=True,
+ v3=smp.state.cfg.shard_optimizer_state,
+ )
+ if self.args.should_save:
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+ if self.do_grad_scaling:
+ torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+ elif self.args.should_save and not self.is_deepspeed_enabled:
+ # deepspeed.save_checkpoint above saves model/optim/sched
+ if self.fsdp:
+ torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
+ else:
+ torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+ reissue_pt_warnings(caught_warnings)
+ if self.do_grad_scaling:
+ torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+
+ # Determine the new best metric / best model checkpoint
+ if metrics is not None and self.args.metric_for_best_model is not None:
+ metric_to_check = self.args.metric_for_best_model
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+ metric_value = metrics[metric_to_check]
+
+ operator = np.greater if self.args.greater_is_better else np.less
+ if (
+ self.state.best_metric is None
+ or self.state.best_model_checkpoint is None
+ or operator(metric_value, self.state.best_metric)
+ ):
+ self.state.best_metric = metric_value
+ self.state.best_model_checkpoint = output_dir
+
+ # Save the Trainer state
+ if self.args.should_save:
+ self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
+
+ # Save RNG state in non-distributed training
+ rng_states = {
+ "python": random.getstate(),
+ "numpy": np.random.get_state(),
+ "cpu": torch.random.get_rng_state(),
+ }
+ if torch.cuda.is_available():
+ if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ # In distributed training, save the CUDA RNG state of each device
+ rng_states["cuda"] = torch.cuda.random.get_rng_state_all()
+ else:
+ rng_states["cuda"] = torch.cuda.random.get_rng_state()
+
+ if is_torch_tpu_available():
+ rng_states["xla"] = xm.get_rng_state()
+
+ # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may
+ # not yet exist.
+ os.makedirs(output_dir, exist_ok=True)
+
+ if self.args.world_size <= 1:
+ torch.save(rng_states, os.path.join(output_dir, "rng_state.pth"))
+ else:
+ torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth"))
+
+ if self.args.push_to_hub:
+ self._push_from_checkpoint(output_dir)
+
+ # Maybe delete some older checkpoints.
+ if self.args.should_save:
+ self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)
+
+ def _load_optimizer_and_scheduler(self, checkpoint):
+ """If optimizer and scheduler states exist, load them."""
+ if checkpoint is None:
+ return
+
+ if self.is_deepspeed_enabled:
+ # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init
+ return
+
+ checkpoint_file_exists = (
+ glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*")
+ if is_sagemaker_mp_enabled()
+ else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME))
+ )
+ if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)):
+ # Load in optimizer and scheduler states
+ if is_torch_tpu_available():
+ # On TPU we have to take some extra precautions to properly load the states on the right device.
+ optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu")
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu")
+ reissue_pt_warnings(caught_warnings)
+
+ xm.send_cpu_data_to_device(optimizer_state, self.args.device)
+ xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device)
+
+ self.optimizer.load_state_dict(optimizer_state)
+ self.lr_scheduler.load_state_dict(lr_scheduler_state)
+ else:
+ if is_sagemaker_mp_enabled():
+ if os.path.isfile(os.path.join(checkpoint, "user_content.pt")):
+ # Optimizer checkpoint was saved with smp >= 1.10
+ def opt_load_hook(mod, opt):
+ opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+ else:
+ # Optimizer checkpoint was saved with smp < 1.10
+ def opt_load_hook(mod, opt):
+ if IS_SAGEMAKER_MP_POST_1_10:
+ opt.load_state_dict(
+ smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True)
+ )
+ else:
+ opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True))
+
+ self.model_wrapped.register_post_step_hook(opt_load_hook)
+ else:
+ # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models.
+ # In distributed training however, we load directly on each GPU and risk a GPU OOM, as it's more
+ # likely to get an OOM on CPU (since we load the optimizer state num_gpu times).
+ map_location = self.args.device if self.args.world_size > 1 else "cpu"
+ if self.fsdp:
+ full_osd = None
+ # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it
+ if self.args.process_index == 0:
+ full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME))
+ # call scatter_full_optim_state_dict on all ranks
+ sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model)
+ self.optimizer.load_state_dict(sharded_osd)
+ else:
+ self.optimizer.load_state_dict(
+ torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)
+ )
+ with warnings.catch_warnings(record=True) as caught_warnings:
+ self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
+ reissue_pt_warnings(caught_warnings)
+ if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
+ self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
+
+ def hyperparameter_search(
+ self,
+ hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None,
+ compute_objective: Optional[Callable[[Dict[str, float]], float]] = None,
+ n_trials: int = 20,
+ direction: str = "minimize",
+ backend: Optional[Union["str", HPSearchBackend]] = None,
+ hp_name: Optional[Callable[["optuna.Trial"], str]] = None,
+ **kwargs,
+ ) -> BestRun:
+ """
+ Launch a hyperparameter search using `optuna`, `Ray Tune` or `SigOpt`. The optimized quantity is determined
+ by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided,
+ and the sum of all metrics otherwise.
+
+
+
+ To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to
+ reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to
+ subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom
+ optimizer/scheduler.
+
+
+
+ Args:
+ hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*):
+ A function that defines the hyperparameter search space. Will default to
+ [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or
+ [`~trainer_utils.default_hp_space_sigopt`] depending on your backend.
+ compute_objective (`Callable[[Dict[str, float]], float]`, *optional*):
+ A function computing the objective to minimize or maximize from the metrics returned by the `evaluate`
+ method. Will default to [`~trainer_utils.default_compute_objective`].
+ n_trials (`int`, *optional*, defaults to 20):
+ The number of trial runs to test.
+ direction (`str`, *optional*, defaults to `"minimize"`):
+ Whether to optimize for a greater or lower objective. Can be `"minimize"` or `"maximize"`, you should pick
+ `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics.
+ backend (`str` or [`~training_utils.HPSearchBackend`], *optional*):
+ The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending
+ on which one is installed. If all are installed, will default to optuna.
+ hp_name (`Callable[["optuna.Trial"], str]]`, *optional*):
+ A function that defines the trial/run name. Will default to None.
+ kwargs (`Dict[str, Any]`, *optional*):
+ Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more
+ information see:
+
+ - the documentation of
+ [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html)
+ - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run)
+ - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create)
+
+ Returns:
+ [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in
+ `run_summary` attribute for Ray backend.
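+
+ Example (a minimal, illustrative sketch, not taken from this repository; it assumes a `trainer` built with a
+ `model_init` function and `optuna` installed):
+
+ ```python
+ def my_hp_space(trial):
+     return {
+         "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
+         "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
+     }
+
+ best_run = trainer.hyperparameter_search(
+     hp_space=my_hp_space, n_trials=10, direction="minimize", backend="optuna"
+ )
+ print(best_run.hyperparameters)
+ ```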
+ """
+ if backend is None:
+ backend = default_hp_search_backend()
+ if backend is None:
+ raise RuntimeError(
+ "At least one of optuna or ray should be installed. "
+ "To install optuna run `pip install optuna`. "
+ "To install ray run `pip install ray[tune]`. "
+ "To install sigopt run `pip install sigopt`."
+ )
+ backend = HPSearchBackend(backend)
+ if backend == HPSearchBackend.OPTUNA and not is_optuna_available():
+ raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.")
+ if backend == HPSearchBackend.RAY and not is_ray_tune_available():
+ raise RuntimeError(
+ "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`."
+ )
+ if backend == HPSearchBackend.SIGOPT and not is_sigopt_available():
+ raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.")
+ if backend == HPSearchBackend.WANDB and not is_wandb_available():
+ raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.")
+ self.hp_search_backend = backend
+ if self.model_init is None:
+ raise RuntimeError(
+ "To use hyperparameter search, you need to pass your model through a model_init function."
+ )
+
+ self.hp_space = default_hp_space[backend] if hp_space is None else hp_space
+ self.hp_name = hp_name
+ self.compute_objective = default_compute_objective if compute_objective is None else compute_objective
+
+ backend_dict = {
+ HPSearchBackend.OPTUNA: run_hp_search_optuna,
+ HPSearchBackend.RAY: run_hp_search_ray,
+ HPSearchBackend.SIGOPT: run_hp_search_sigopt,
+ HPSearchBackend.WANDB: run_hp_search_wandb,
+ }
+ best_run = backend_dict[backend](self, n_trials, direction, **kwargs)
+
+ self.hp_search_backend = None
+ return best_run
+
+ def log(self, logs: Dict[str, float]) -> None:
+ """
+ Log `logs` on the various objects watching training.
+
+ Subclass and override this method to inject custom behavior.
+
+ Args:
+ logs (`Dict[str, float]`):
+ The values to log.
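+
+ A minimal override sketch (illustrative only; `MyTrainer` is a hypothetical name and the surrounding class is
+ assumed to be importable as `Trainer`):
+
+ ```python
+ class MyTrainer(Trainer):
+     def log(self, logs):
+         # add a custom scalar before the standard logging path runs
+         logs["gpu_mem_gb"] = torch.cuda.max_memory_allocated() / 1e9 if torch.cuda.is_available() else 0.0
+         super().log(logs)
+ ```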
+ """
+ if self.state.epoch is not None:
+ logs["epoch"] = round(self.state.epoch, 2)
+
+ output = {**logs, **{"step": self.state.global_step}}
+ self.state.log_history.append(output)
+ self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
+
+ def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]:
+ """
+ Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors.
+ """
+ if isinstance(data, Mapping):
+ return type(data)({k: self._prepare_input(v) for k, v in data.items()})
+ elif isinstance(data, (tuple, list)):
+ return type(data)(self._prepare_input(v) for v in data)
+ elif isinstance(data, torch.Tensor):
+ kwargs = {"device": self.args.device}
+ if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)):
+ # NLP model inputs are int/uint and get adjusted to the right dtype of the
+ # embedding. Other models, such as wav2vec2, have float inputs and thus
+ # may need special handling to match the dtype of the model.
+ kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()})
+ return data.to(**kwargs)
+ return data
+
+ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
+ """
+ Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and
+ handling potential state.
+ """
+ inputs = self._prepare_input(inputs)
+ if len(inputs) == 0:
+ raise ValueError(
+ "The batch received was empty, your model won't be able to train on it. Double-check that your "
+ f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}."
+ )
+ if self.args.past_index >= 0 and self._past is not None:
+ inputs["mems"] = self._past
+
+ return inputs
+
+ def compute_loss_context_manager(self):
+ """
+ A helper wrapper to group together context managers.
+ """
+ return self.autocast_smart_context_manager()
+
+ def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True):
+ """
+ A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
+ arguments, depending on the situation.
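+
+ A minimal usage sketch (illustrative; assumes an already constructed `trainer` and a prepared `batch` dict of
+ model inputs):
+
+ ```python
+ with trainer.autocast_smart_context_manager():
+     outputs = trainer.model(**batch)
+ ```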
+ """
+ if self.use_cuda_amp or self.use_cpu_amp:
+ if is_torch_greater_or_equal_than_1_10:
+ ctx_manager = (
+ torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
+ if self.use_cpu_amp
+ else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype)
+ )
+ else:
+ ctx_manager = torch.cuda.amp.autocast()
+ else:
+ ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress()
+
+ return ctx_manager
+
+ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
+ """
+ Perform a training step on a batch of inputs.
+
+ Subclass and override to inject custom behavior.
+
+ Args:
+ model (`nn.Module`):
+ The model to train.
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument `labels`. Check your model's documentation for all accepted arguments.
+
+ Return:
+ `torch.Tensor`: The tensor with training loss on this batch.
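+
+ A minimal override sketch (illustrative; assumes the surrounding class is importable as `Trainer`):
+
+ ```python
+ class LoggingTrainer(Trainer):
+     def training_step(self, model, inputs):
+         # reuse the full default step (AMP, apex, accelerate backward) and only add a print
+         loss = super().training_step(model, inputs)
+         if self.state.global_step % 100 == 0:
+             print(f"step {self.state.global_step}: loss={loss.item():.4f}")
+         return loss
+ ```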
+ """
+ model.train()
+ inputs = self._prepare_inputs(inputs)
+
+ if is_sagemaker_mp_enabled():
+ loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
+ return loss_mb.reduce_mean().detach().to(self.args.device)
+
+ with self.compute_loss_context_manager():
+ loss = self.compute_loss(model, inputs)
+
+ if self.args.n_gpu > 1:
+ loss = loss.mean() # mean() to average on multi-gpu parallel training
+
+ if self.do_grad_scaling:
+ self.scaler.scale(loss).backward()
+ elif self.use_apex:
+ with amp.scale_loss(loss, self.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ self.accelerator.backward(loss)
+
+ return loss.detach() / self.args.gradient_accumulation_steps
+
+ def compute_loss(self, model, inputs, return_outputs=False):
+ """
+ How the loss is computed by Trainer. By default, all models return the loss in the first element.
+
+ Subclass and override for custom behavior.
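+
+ A minimal override sketch for a classification model (illustrative; `class_weights` is a hypothetical tensor of
+ per-class weights, and the surrounding class is assumed to be importable as `Trainer`):
+
+ ```python
+ class WeightedLossTrainer(Trainer):
+     def compute_loss(self, model, inputs, return_outputs=False):
+         labels = inputs.pop("labels")
+         outputs = model(**inputs)
+         logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0]
+         # weight the cross-entropy by class frequency (class_weights defined elsewhere)
+         loss_fct = nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
+         loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
+         return (loss, outputs) if return_outputs else loss
+ ```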
+ """
+ if self.label_smoother is not None and "labels" in inputs:
+ labels = inputs.pop("labels")
+ else:
+ labels = None
+ outputs = model(**inputs)
+ # Save past state if it exists
+ # TODO: this needs to be fixed and made cleaner later.
+ if self.args.past_index >= 0:
+ self._past = outputs[self.args.past_index]
+
+ if labels is not None:
+ if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
+ loss = self.label_smoother(outputs, labels, shift_labels=True)
+ else:
+ loss = self.label_smoother(outputs, labels)
+ else:
+ if isinstance(outputs, dict) and "loss" not in outputs:
+ raise ValueError(
+ "The model did not return a loss from the inputs, only the following keys: "
+ f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
+ )
+ # We don't use .loss here since the model may return tuples instead of ModelOutput.
+ loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
+
+ return (loss, outputs) if return_outputs else loss
+
+ def is_local_process_zero(self) -> bool:
+ """
+ Whether or not this process is the local main process (e.g., the main process on this machine when training
+ in a distributed fashion on several machines).
+ """
+ return self.args.local_process_index == 0
+
+ def is_world_process_zero(self) -> bool:
+ """
+ Whether or not this process is the global main process (when training in a distributed fashion on several
+ machines, this is only going to be `True` for one process).
+ """
+ # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global
+ # process index.
+ if is_sagemaker_mp_enabled():
+ return smp.rank() == 0
+ else:
+ return self.args.process_index == 0
+
+ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
+ """
+ Will save the model, so you can reload it using `from_pretrained()`.
+
+ Will only save from the main process.
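+
+ A minimal reload sketch (illustrative; "./my_output_dir" is a placeholder path that is assumed to contain the
+ files written by this method):
+
+ ```python
+ from transformers import AutoModel, AutoTokenizer
+
+ model = AutoModel.from_pretrained("./my_output_dir")
+ tokenizer = AutoTokenizer.from_pretrained("./my_output_dir")
+ ```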
+ """
+
+ if output_dir is None:
+ output_dir = self.args.output_dir
+
+ if is_torch_tpu_available():
+ self._save_tpu(output_dir)
+ elif is_sagemaker_mp_enabled():
+ # Calling the state_dict needs to be done on the wrapped model and on all processes.
+ os.makedirs(output_dir, exist_ok=True)
+ state_dict = self.model_wrapped.state_dict()
+ if self.args.should_save:
+ self._save(output_dir, state_dict=state_dict)
+ if IS_SAGEMAKER_MP_POST_1_10:
+ # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
+ Path(os.path.join(output_dir, "user_content.pt")).touch()
+ elif (
+ ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
+ or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
+ or self.fsdp is not None
+ or self.is_fsdp_enabled
+ ):
+ if self.is_fsdp_enabled:
+ os.makedirs(output_dir, exist_ok=True)
+ self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
+ else:
+ state_dict = self.model.state_dict()
+
+ if self.args.should_save:
+ self._save(output_dir, state_dict=state_dict)
+ elif self.is_deepspeed_enabled:
+ # this takes care of everything as long as we aren't under zero3
+ if self.args.should_save:
+ self._save(output_dir)
+
+ if is_deepspeed_zero3_enabled():
+ # It's too complicated to try to override different places where the weights dump gets
+ # saved, so since under zero3 the file is bogus, simply delete it. The user should
+ # either use the deepspeed checkpoint to resume, or recover the full weights with the
+ # zero_to_fp32.py script stored in the checkpoint.
+ if self.args.should_save:
+ file = os.path.join(output_dir, WEIGHTS_NAME)
+ if os.path.isfile(file):
+ # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights")
+ os.remove(file)
+
+ # now save the real model if stage3_gather_16bit_weights_on_model_save=True
+ # if false it will not be saved.
+ # This must be called on all ranks
+ if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME):
+ logger.warning(
+ "deepspeed.save_16bit_model didn't save the model, since"
+ " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
+ " zero_to_fp32.py to recover weights"
+ )
+ self.model_wrapped.save_checkpoint(output_dir)
+
+ elif self.args.should_save:
+ self._save(output_dir)
+
+ # Push to the Hub when `save_model` is called by the user.
+ if self.args.push_to_hub and not _internal_call:
+ self.push_to_hub(commit_message="Model save")
+
+ def _save_tpu(self, output_dir: Optional[str] = None):
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+ logger.info(f"Saving model checkpoint to {output_dir}")
+
+ if xm.is_master_ordinal():
+ os.makedirs(output_dir, exist_ok=True)
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ # Save a trained model and configuration using `save_pretrained()`.
+ # They can then be reloaded using `from_pretrained()`
+ xm.rendezvous("saving_checkpoint")
+ if not isinstance(self.model, PreTrainedModel):
+ if isinstance(unwrap_model(self.model), PreTrainedModel):
+ unwrap_model(self.model).save_pretrained(
+ output_dir,
+ is_main_process=self.args.should_save,
+ state_dict=self.model.state_dict(),
+ save_function=xm.save,
+ )
+ else:
+ logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+ state_dict = self.model.state_dict()
+ xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+ else:
+ self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
+ if self.tokenizer is not None and self.args.should_save:
+ self.tokenizer.save_pretrained(output_dir)
+
+ def _save(self, output_dir: Optional[str] = None, state_dict=None):
+ # If we are executing this function, we are the process zero, so we don't check for that.
+ output_dir = output_dir if output_dir is not None else self.args.output_dir
+ os.makedirs(output_dir, exist_ok=True)
+ logger.info(f"Saving model checkpoint to {output_dir}")
+
+ supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
+ # Save a trained model and configuration using `save_pretrained()`.
+ # They can then be reloaded using `from_pretrained()`
+ if not isinstance(self.model, supported_classes):
+ if state_dict is None:
+ state_dict = self.model.state_dict()
+
+ if isinstance(unwrap_model(self.model), supported_classes):
+ unwrap_model(self.model).save_pretrained(
+ output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+ )
+ else:
+ logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+ if self.args.save_safetensors:
+ safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME))
+ else:
+ torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+ else:
+ self.model.save_pretrained(
+ output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+ )
+
+ if self.tokenizer is not None:
+ self.tokenizer.save_pretrained(output_dir)
+
+ # Good practice: save your training arguments together with the trained model
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ def store_flos(self):
+ # Storing the number of floating-point operations that went into the model
+ if self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ self.state.total_flos += (
+ distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item()
+ )
+ self.current_flos = 0
+ else:
+ self.state.total_flos += self.current_flos
+ self.current_flos = 0
+
+ def _sorted_checkpoints(
+ self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False
+ ) -> List[str]:
+ ordering_and_checkpoint_path = []
+
+ glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)]
+
+ for path in glob_checkpoints:
+ if use_mtime:
+ ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
+ else:
+ regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path)
+ if regex_match is not None and regex_match.groups() is not None:
+ ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))
+
+ checkpoints_sorted = sorted(ordering_and_checkpoint_path)
+ checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
+ # Make sure we don't delete the best model.
+ if self.state.best_model_checkpoint is not None:
+ best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
+ for i in range(best_model_index, len(checkpoints_sorted) - 2):
+ checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
+ return checkpoints_sorted
+
+ def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None:
+ if self.args.save_total_limit is None or self.args.save_total_limit <= 0:
+ return
+
+ # Check if we should delete older checkpoint(s)
+ checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir)
+ if len(checkpoints_sorted) <= self.args.save_total_limit:
+ return
+
+ # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which
+ # we don't do to allow resuming.
+ save_total_limit = self.args.save_total_limit
+ if (
+ self.state.best_model_checkpoint is not None
+ and self.args.save_total_limit == 1
+ and checkpoints_sorted[-1] != self.state.best_model_checkpoint
+ ):
+ save_total_limit = 2
+
+ number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit)
+ checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
+ for checkpoint in checkpoints_to_be_deleted:
+ logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
+ shutil.rmtree(checkpoint, ignore_errors=True)
+
+ def evaluate(
+ self,
+ eval_dataset: Optional[Dataset] = None,
+ ignore_keys: Optional[List[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> Dict[str, float]:
+ """
+ Run evaluation and returns metrics.
+
+ The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
+ (pass it to the init `compute_metrics` argument).
+
+ You can also subclass and override this method to inject custom behavior.
+
+ Args:
+ eval_dataset (`Dataset`, *optional*):
+ Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns
+ not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
+ method.
+ ignore_keys (`List[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+ metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
+ An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
+ "eval_bleu" if the prefix is "eval" (default).
+
+ Returns:
+ A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
+ dictionary also contains the epoch number which comes from the training state.
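+
+ A minimal usage sketch (illustrative; assumes an already constructed `trainer` and a prepared `val_dataset`):
+
+ ```python
+ metrics = trainer.evaluate(eval_dataset=val_dataset, metric_key_prefix="val")
+ print(metrics["val_loss"])
+ ```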
+ """
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ eval_dataloader = self.get_eval_dataloader(eval_dataset)
+ start_time = time.time()
+
+ eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+ output = eval_loop(
+ eval_dataloader,
+ description="Evaluation",
+ # No point gathering the predictions if there are no metrics, otherwise we defer to
+ # self.args.prediction_loss_only
+ prediction_loss_only=True if self.compute_metrics is None else None,
+ ignore_keys=ignore_keys,
+ metric_key_prefix=metric_key_prefix,
+ )
+
+ total_batch_size = self.args.eval_batch_size * self.args.world_size
+ if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+ start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+ output.metrics.update(
+ speed_metrics(
+ metric_key_prefix,
+ start_time,
+ num_samples=output.num_samples,
+ num_steps=math.ceil(output.num_samples / total_batch_size),
+ )
+ )
+
+ self.log(output.metrics)
+
+ if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+
+ self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
+
+ self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+ return output.metrics
+
+ def predict(
+ self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
+ ) -> PredictionOutput:
+ """
+ Run prediction and returns predictions and potential metrics.
+
+ Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
+ will also return metrics, like in `evaluate()`.
+
+ Args:
+ test_dataset (`Dataset`):
+ Dataset to run the predictions on. If it is a `datasets.Dataset`, columns not accepted by the
+ `model.forward()` method are automatically removed. It has to implement the method `__len__`.
+ ignore_keys (`List[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+ metric_key_prefix (`str`, *optional*, defaults to `"test"`):
+ An optional prefix to be used as the metrics key prefix. For example, the metric "bleu" will be named
+ "test_bleu" if the prefix is "test" (default).
+
+
+
+ If your predictions or labels have different sequence lengths (for instance because you're doing dynamic padding
+ in a token classification task), the predictions will be padded (on the right) to allow for concatenation into
+ one array. The padding index is -100.
+
+
+
+ Returns: *NamedTuple* A namedtuple with the following keys:
+
+ - predictions (`np.ndarray`): The predictions on `test_dataset`.
+ - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
+ - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
+ labels).
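+
+ A minimal usage sketch (illustrative; assumes an already constructed `trainer` and a prepared `test_dataset`,
+ and a model whose predictions are classification logits):
+
+ ```python
+ output = trainer.predict(test_dataset)
+ preds = output.predictions.argmax(axis=-1)  # class ids from the logits
+ print(output.metrics)
+ ```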
+ """
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ test_dataloader = self.get_test_dataloader(test_dataset)
+ start_time = time.time()
+
+ eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+ output = eval_loop(
+ test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
+ )
+ total_batch_size = self.args.eval_batch_size * self.args.world_size
+ if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
+ start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
+ output.metrics.update(
+ speed_metrics(
+ metric_key_prefix,
+ start_time,
+ num_samples=output.num_samples,
+ num_steps=math.ceil(output.num_samples / total_batch_size),
+ )
+ )
+
+ self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics)
+ self._memory_tracker.stop_and_update_metrics(output.metrics)
+
+ return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics)
+
+ def evaluation_loop(
+ self,
+ dataloader: DataLoader,
+ description: str,
+ prediction_loss_only: Optional[bool] = None,
+ ignore_keys: Optional[List[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> EvalLoopOutput:
+ """
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+ Works both with or without labels.
+ """
+ args = self.args
+
+ prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+ # if eval is called w/o train, handle model prep here
+ if self.is_deepspeed_enabled and self.model_wrapped is self.model:
+ _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+ if len(self.accelerator._models) == 0 and model is self.model:
+ model = (
+ self.accelerator.prepare(model)
+ if self.is_deepspeed_enabled
+ else self.accelerator.prepare_model(model, evaluation_mode=True)
+ )
+
+ if self.is_fsdp_enabled:
+ self.model = model
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if model is not self.model:
+ self.model_wrapped = model
+
+ # backward compatibility
+ if self.is_deepspeed_enabled:
+ self.deepspeed = self.model_wrapped
+
+ # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+ # while ``train`` is running, cast it to the right dtype first and then put on device
+ if not self.is_in_train:
+ if args.fp16_full_eval:
+ model = model.to(dtype=torch.float16, device=args.device)
+ elif args.bf16_full_eval:
+ model = model.to(dtype=torch.bfloat16, device=args.device)
+
+ batch_size = self.args.eval_batch_size
+
+ logger.info(f"***** Running {description} *****")
+ if has_length(dataloader):
+ logger.info(f" Num examples = {self.num_examples(dataloader)}")
+ else:
+ logger.info(" Num examples: Unknown")
+ logger.info(f" Batch size = {batch_size}")
+
+ model.eval()
+
+ self.callback_handler.eval_dataloader = dataloader
+ # Do this before wrapping.
+ eval_dataset = getattr(dataloader, "dataset", None)
+
+ if is_torch_tpu_available():
+ dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
+
+ if args.past_index >= 0:
+ self._past = None
+
+ # Initialize containers
+ # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps)
+ losses_host = None
+ preds_host = None
+ labels_host = None
+ inputs_host = None
+
+ # losses/preds/labels on CPU (final containers)
+ all_losses = None
+ all_preds = None
+ all_labels = None
+ all_inputs = None
+ # Will be useful when we have an iterable dataset so don't know its length.
+
+ observed_num_examples = 0
+ # Main evaluation loop
+ for step, inputs in enumerate(dataloader):
+ # Update the observed num examples
+ observed_batch_size = find_batch_size(inputs)
+ if observed_batch_size is not None:
+ observed_num_examples += observed_batch_size
+ # For batch samplers, batch_size is not known by the dataloader in advance.
+ if batch_size is None:
+ batch_size = observed_batch_size
+
+ # Prediction step
+ loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+ inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
+
+ if is_torch_tpu_available():
+ xm.mark_step()
+
+ # Update containers on host
+ if loss is not None:
+ losses = self._nested_gather(loss.repeat(batch_size))
+ losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+ if labels is not None:
+ labels = self._pad_across_processes(labels)
+ if inputs_decode is not None:
+ inputs_decode = self._pad_across_processes(inputs_decode)
+ inputs_decode = self._nested_gather(inputs_decode)
+ inputs_host = (
+ inputs_decode
+ if inputs_host is None
+ else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+ )
+ if logits is not None:
+ logits = self._pad_across_processes(logits)
+ if self.preprocess_logits_for_metrics is not None:
+ logits = self.preprocess_logits_for_metrics(logits, labels)
+ logits = self._nested_gather(logits)
+ preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+ if labels is not None:
+ labels = self._nested_gather(labels)
+ labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+ self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+ # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+ if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+ if losses_host is not None:
+ losses = nested_numpify(losses_host)
+ all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+ if preds_host is not None:
+ logits = nested_numpify(preds_host)
+ all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+ if inputs_host is not None:
+ inputs_decode = nested_numpify(inputs_host)
+ all_inputs = (
+ inputs_decode
+ if all_inputs is None
+ else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+ )
+ if labels_host is not None:
+ labels = nested_numpify(labels_host)
+ all_labels = (
+ labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+ )
+
+ # Set back to None to begin a new accumulation
+ losses_host, preds_host, inputs_host, labels_host = None, None, None, None
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of the evaluation loop
+ delattr(self, "_past")
+
+ # Gather all remaining tensors and put them back on the CPU
+ if losses_host is not None:
+ losses = nested_numpify(losses_host)
+ all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
+ if preds_host is not None:
+ logits = nested_numpify(preds_host)
+ all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100)
+ if inputs_host is not None:
+ inputs_decode = nested_numpify(inputs_host)
+ all_inputs = (
+ inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100)
+ )
+ if labels_host is not None:
+ labels = nested_numpify(labels_host)
+ all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100)
+
+ # Number of samples
+ if has_length(eval_dataset):
+ num_samples = len(eval_dataset)
+ # The instance check is weird and does not actually check for the type, but whether the dataset has the right
+ # methods. Therefore we need to make sure it also has the attribute.
+ elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0:
+ num_samples = eval_dataset.num_examples
+ else:
+ if has_length(dataloader):
+ num_samples = self.num_examples(dataloader)
+ else: # both len(dataloader.dataset) and len(dataloader) fail
+ num_samples = observed_num_examples
+ if num_samples == 0 and observed_num_examples > 0:
+ num_samples = observed_num_examples
+
+ # The number of losses has been rounded to a multiple of batch_size and, in distributed training, the number of
+ # samples has been rounded to a multiple of batch_size, so we truncate.
+ if all_losses is not None:
+ all_losses = all_losses[:num_samples]
+ if all_preds is not None:
+ all_preds = nested_truncate(all_preds, num_samples)
+ if all_labels is not None:
+ all_labels = nested_truncate(all_labels, num_samples)
+ if all_inputs is not None:
+ all_inputs = nested_truncate(all_inputs, num_samples)
+
+ # Metrics!
+ if self.compute_metrics is not None and all_preds is not None and all_labels is not None:
+ if args.include_inputs_for_metrics:
+ metrics = self.compute_metrics(
+ EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs)
+ )
+ else:
+ metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels))
+ else:
+ metrics = {}
+
+ # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+ metrics = denumpify_detensorize(metrics)
+
+ if all_losses is not None:
+ metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item()
+ if hasattr(self, "jit_compilation_time"):
+ metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time
+
+ # Prefix all keys with metric_key_prefix + '_'
+ for key in list(metrics.keys()):
+ if not key.startswith(f"{metric_key_prefix}_"):
+ metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+ return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)
+
+ def _nested_gather(self, tensors, name=None):
+ """
+ Gather the value of `tensors` (a tensor or a nested list/tuple of tensors) across all processes so they can
+ be concatenated.
+ """
+ if tensors is None:
+ return
+ if is_torch_tpu_available():
+ if name is None:
+ name = "nested_gather"
+ tensors = nested_xla_mesh_reduce(tensors, name)
+ elif is_sagemaker_mp_enabled():
+ tensors = smp_gather(tensors)
+ elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or (
+ self.args.distributed_state is None and self.local_rank != -1
+ ):
+ tensors = distributed_concat(tensors)
+ return tensors
+
+ # Copied from Accelerate.
+ def _pad_across_processes(self, tensor, pad_index=-100):
+ """
+ Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so
+ they can safely be gathered.
+ """
+ if isinstance(tensor, (list, tuple)):
+ return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor)
+ elif isinstance(tensor, dict):
+ return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()})
+ elif not isinstance(tensor, torch.Tensor):
+ raise TypeError(
+ f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors."
+ )
+
+ if len(tensor.shape) < 2:
+ return tensor
+ # Gather all sizes
+ size = torch.tensor(tensor.shape, device=tensor.device)[None]
+ sizes = self._nested_gather(size).cpu()
+
+ max_size = max(s[1] for s in sizes)
+ # When extracting XLA graphs for compilation, max_size is 0,
+ # so use inequality to avoid errors.
+ if tensor.shape[1] >= max_size:
+ return tensor
+
+ # Then pad to the maximum size
+ old_size = tensor.shape
+ new_size = list(old_size)
+ new_size[1] = max_size
+ new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index
+ new_tensor[:, : old_size[1]] = tensor
+ return new_tensor
+
+ def prediction_step(
+ self,
+ model: nn.Module,
+ inputs: Dict[str, Union[torch.Tensor, Any]],
+ prediction_loss_only: bool,
+ ignore_keys: Optional[List[str]] = None,
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+ """
+ Perform an evaluation step on `model` using `inputs`.
+
+ Subclass and override to inject custom behavior.
+
+ Args:
+ model (`nn.Module`):
+ The model to evaluate.
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument `labels`. Check your model's documentation for all accepted arguments.
+ prediction_loss_only (`bool`):
+ Whether or not to return the loss only.
+ ignore_keys (`List[str]`, *optional*):
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions.
+
+ Return:
+ Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss,
+ logits and labels (each being optional).
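+
+ A minimal override sketch (illustrative; assumes the model returns a single logits tensor and the surrounding
+ class is importable as `Trainer`):
+
+ ```python
+ class ArgmaxTrainer(Trainer):
+     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
+         loss, logits, labels = super().prediction_step(
+             model, inputs, prediction_loss_only, ignore_keys=ignore_keys
+         )
+         if logits is not None and isinstance(logits, torch.Tensor):
+             logits = logits.argmax(dim=-1)  # shrink what gets gathered across processes
+         return loss, logits, labels
+ ```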
+ """
+ has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names)
+ # For CLIP-like models capable of returning loss values.
+ # If `return_loss` is not specified or is `None` in `inputs`, we check whether the default value of `return_loss`
+ # is `True` in `model.forward`.
+ return_loss = inputs.get("return_loss", None)
+ if return_loss is None:
+ return_loss = self.can_return_loss
+ loss_without_labels = True if len(self.label_names) == 0 and return_loss else False
+
+ inputs = self._prepare_inputs(inputs)
+ if ignore_keys is None:
+ if hasattr(self.model, "config"):
+ ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
+ else:
+ ignore_keys = []
+
+ # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
+ if has_labels or loss_without_labels:
+ labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
+ if len(labels) == 1:
+ labels = labels[0]
+ else:
+ labels = None
+
+ with torch.no_grad():
+ if is_sagemaker_mp_enabled():
+ raw_outputs = smp_forward_only(model, inputs)
+ if has_labels or loss_without_labels:
+ if isinstance(raw_outputs, dict):
+ loss_mb = raw_outputs["loss"]
+ logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"])
+ else:
+ loss_mb = raw_outputs[0]
+ logits_mb = raw_outputs[1:]
+
+ loss = loss_mb.reduce_mean().detach().cpu()
+ logits = smp_nested_concat(logits_mb)
+ else:
+ loss = None
+ if isinstance(raw_outputs, dict):
+ logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys)
+ else:
+ logits_mb = raw_outputs
+ logits = smp_nested_concat(logits_mb)
+ else:
+ if has_labels or loss_without_labels:
+ with self.compute_loss_context_manager():
+ loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+ loss = loss.mean().detach()
+
+ if isinstance(outputs, dict):
+ logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
+ else:
+ logits = outputs[1:]
+ else:
+ loss = None
+ with self.compute_loss_context_manager():
+ outputs = model(**inputs)
+ if isinstance(outputs, dict):
+ logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
+ else:
+ logits = outputs
+ # TODO: this needs to be fixed and made cleaner later.
+ if self.args.past_index >= 0:
+ self._past = outputs[self.args.past_index - 1]
+
+ if prediction_loss_only:
+ return (loss, None, None)
+
+ logits = nested_detach(logits)
+ if len(logits) == 1:
+ logits = logits[0]
+ return (loss, logits, labels)
+
+ def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
+ """
+ For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point
+ operations for every backward + forward pass. If using another model, either implement such a method in the
+ model or subclass and override this method.
+
+ Args:
+ inputs (`Dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ Returns:
+ `int`: The number of floating-point operations.
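+
+ A minimal override sketch for models without a `floating_point_ops` method (illustrative; the
+ 6 * parameters * tokens rule of thumb is an assumption, not something defined in this file):
+
+ ```python
+ class EstimatedFlosTrainer(Trainer):
+     def floating_point_ops(self, inputs):
+         if "input_ids" not in inputs:
+             return 0
+         num_params = sum(p.numel() for p in self.model.parameters())
+         # rough forward+backward estimate: 6 * parameters * tokens
+         return 6 * num_params * inputs["input_ids"].numel()
+ ```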
+ """
+ if hasattr(self.model, "floating_point_ops"):
+ return self.model.floating_point_ops(inputs)
+ else:
+ return 0
+
+ def init_git_repo(self, at_init: bool = False):
+ """
+ Initializes a git repo in `self.args.hub_model_id`.
+
+ Args:
+ at_init (`bool`, *optional*, defaults to `False`):
+ Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is
+ `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped
+ out.
+ """
+ if not self.is_world_process_zero():
+ return
+ if self.args.hub_model_id is None:
+ repo_name = Path(self.args.output_dir).absolute().name
+ else:
+ repo_name = self.args.hub_model_id
+ if "/" not in repo_name:
+ repo_name = get_full_repo_name(repo_name, token=self.args.hub_token)
+
+ # Make sure the repo exists.
+ create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True)
+ try:
+ self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
+ except EnvironmentError:
+ if self.args.overwrite_output_dir and at_init:
+ # Try again after wiping output_dir
+ shutil.rmtree(self.args.output_dir)
+ self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token)
+ else:
+ raise
+
+ self.repo.git_pull()
+
+ # By default, ignore the checkpoint folders
+ if (
+ not os.path.exists(os.path.join(self.args.output_dir, ".gitignore"))
+ and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS
+ ):
+ with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer:
+ writer.writelines(["checkpoint-*/"])
+
+ # Add "*.sagemaker" to .gitignore if using SageMaker
+ if os.environ.get("SM_TRAINING_ENV"):
+ self._add_sm_patterns_to_gitignore()
+
+ self.push_in_progress = None
+
+ def create_model_card(
+ self,
+ language: Optional[str] = None,
+ license: Optional[str] = None,
+ tags: Union[str, List[str], None] = None,
+ model_name: Optional[str] = None,
+ finetuned_from: Optional[str] = None,
+ tasks: Union[str, List[str], None] = None,
+ dataset_tags: Union[str, List[str], None] = None,
+ dataset: Union[str, List[str], None] = None,
+ dataset_args: Union[str, List[str], None] = None,
+ ):
+ """
+ Creates a draft of a model card using the information available to the `Trainer`.
+
+ Args:
+ language (`str`, *optional*):
+ The language of the model (if applicable)
+ license (`str`, *optional*):
+ The license of the model. Will default to the license of the pretrained model used, if the original
+ model given to the `Trainer` comes from a repo on the Hub.
+ tags (`str` or `List[str]`, *optional*):
+ Some tags to be included in the metadata of the model card.
+ model_name (`str`, *optional*):
+ The name of the model.
+ finetuned_from (`str`, *optional*):
+ The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo
+ of the original model given to the `Trainer` (if it comes from the Hub).
+ tasks (`str` or `List[str]`, *optional*):
+ One or several task identifiers, to be included in the metadata of the model card.
+ dataset_tags (`str` or `List[str]`, *optional*):
+ One or several dataset tags, to be included in the metadata of the model card.
+ dataset (`str` or `List[str]`, *optional*):
+ One or several dataset identifiers, to be included in the metadata of the model card.
+ dataset_args (`str` or `List[str]`, *optional*):
+ One or several dataset arguments, to be included in the metadata of the model card.
+ """
+ if not self.is_world_process_zero():
+ return
+
+ training_summary = TrainingSummary.from_trainer(
+ self,
+ language=language,
+ license=license,
+ tags=tags,
+ model_name=model_name,
+ finetuned_from=finetuned_from,
+ tasks=tasks,
+ dataset_tags=dataset_tags,
+ dataset=dataset,
+ dataset_args=dataset_args,
+ )
+ model_card = training_summary.to_model_card()
+ with open(os.path.join(self.args.output_dir, "README.md"), "w") as f:
+ f.write(model_card)
+
+ def _push_from_checkpoint(self, checkpoint_folder):
+ # Only push from one node.
+ if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END:
+ return
+ # If we haven't finished the last push, we don't do this one.
+ if self.push_in_progress is not None and not self.push_in_progress.is_done:
+ return
+
+ output_dir = self.args.output_dir
+ # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
+ modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
+ for modeling_file in modeling_files:
+ if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
+ shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
+ # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
+ if self.tokenizer is not None:
+ self.tokenizer.save_pretrained(output_dir)
+ # Same for the training arguments
+ torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+ try:
+ if self.args.hub_strategy == HubStrategy.CHECKPOINT:
+ # Temporarily move the checkpoint just saved for the push
+ tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
+ # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
+ # subfolder.
+ if os.path.isdir(tmp_checkpoint):
+ shutil.rmtree(tmp_checkpoint)
+ shutil.move(checkpoint_folder, tmp_checkpoint)
+
+ if self.args.save_strategy == IntervalStrategy.STEPS:
+ commit_message = f"Training in progress, step {self.state.global_step}"
+ else:
+ commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
+ push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
+ # Return type of `Repository.push_to_hub` is either None or a tuple.
+ if push_work is not None:
+ self.push_in_progress = push_work[1]
+ except Exception as e:
+ logger.error(f"Error when pushing to hub: {e}")
+ finally:
+ if self.args.hub_strategy == HubStrategy.CHECKPOINT:
+ # Move back the checkpoint to its place
+ shutil.move(tmp_checkpoint, checkpoint_folder)
+
+ def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
+ """
+ Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
+
+ Parameters:
+ commit_message (`str`, *optional*, defaults to `"End of training"`):
+ Message to commit while pushing.
+ blocking (`bool`, *optional*, defaults to `True`):
+ Whether the function should return only when the `git push` has finished.
+ kwargs:
+ Additional keyword arguments passed along to [`~Trainer.create_model_card`].
+
+ Returns:
+ The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
+ the commit and an object to track the progress of the commit if `blocking=True`
+ """
+        # If a user manually calls `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
+        # it might fail.
+ if not hasattr(self, "repo"):
+ self.init_git_repo()
+
+ model_name = kwargs.pop("model_name", None)
+ if model_name is None and self.args.should_save:
+ if self.args.hub_model_id is None:
+ model_name = Path(self.args.output_dir).name
+ else:
+ model_name = self.args.hub_model_id.split("/")[-1]
+
+        # Needs to be executed on all processes for TPU training, but will only save on the process determined by
+ # self.args.should_save.
+ self.save_model(_internal_call=True)
+
+ # Only push from one node.
+ if not self.is_world_process_zero():
+ return
+
+ # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
+ if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
+ self.push_in_progress._process.kill()
+ self.push_in_progress = None
+
+ git_head_commit_url = self.repo.push_to_hub(
+ commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
+ )
+        # Push the model card separately so that it is independent from the rest of the model
+ if self.args.should_save:
+ self.create_model_card(model_name=model_name, **kwargs)
+ try:
+ self.repo.push_to_hub(
+ commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
+ )
+ except EnvironmentError as exc:
+                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n{exc}")
+
+ return git_head_commit_url
+
+ #
+ # Deprecated code
+ #
+
+ def prediction_loop(
+ self,
+ dataloader: DataLoader,
+ description: str,
+ prediction_loss_only: Optional[bool] = None,
+ ignore_keys: Optional[List[str]] = None,
+ metric_key_prefix: str = "eval",
+ ) -> EvalLoopOutput:
+ """
+ Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
+
+ Works both with or without labels.
+ """
+ args = self.args
+
+ if not has_length(dataloader):
+ raise ValueError("dataloader must implement a working __len__")
+
+ prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
+
+ # if eval is called w/o train, handle model prep here
+ if self.is_deepspeed_enabled and self.model_wrapped is self.model:
+ _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
+
+ model = self._wrap_model(self.model, training=False, dataloader=dataloader)
+
+ if len(self.accelerator._models) == 0 and model is self.model:
+ model = (
+ self.accelerator.prepare(model)
+ if self.is_deepspeed_enabled
+ else self.accelerator.prepare_model(model, evaluation_mode=True)
+ )
+
+ if self.is_fsdp_enabled:
+ self.model = model
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if model is not self.model:
+ self.model_wrapped = model
+
+ # backward compatibility
+ if self.is_deepspeed_enabled:
+ self.deepspeed = self.model_wrapped
+
+ # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
+ # while ``train`` is running, cast it to the right dtype first and then put on device
+ if not self.is_in_train:
+ if args.fp16_full_eval:
+ model = model.to(dtype=torch.float16, device=args.device)
+ elif args.bf16_full_eval:
+ model = model.to(dtype=torch.bfloat16, device=args.device)
+
+ batch_size = dataloader.batch_size
+ num_examples = self.num_examples(dataloader)
+ logger.info(f"***** Running {description} *****")
+ logger.info(f" Num examples = {num_examples}")
+ logger.info(f" Batch size = {batch_size}")
+ losses_host: torch.Tensor = None
+ preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
+ labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
+ inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None
+
+ world_size = max(1, args.world_size)
+
+ eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
+ if not prediction_loss_only:
+ # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
+ # a batch size to the sampler)
+ make_multiple_of = None
+ if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
+ make_multiple_of = dataloader.sampler.batch_size
+ preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+ labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+ inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
+
+ model.eval()
+
+ if is_torch_tpu_available():
+ dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
+
+ if args.past_index >= 0:
+ self._past = None
+
+ self.callback_handler.eval_dataloader = dataloader
+
+ for step, inputs in enumerate(dataloader):
+ loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
+ inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
+
+ if loss is not None:
+ losses = loss.repeat(batch_size)
+ losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
+ if logits is not None:
+ preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
+ if labels is not None:
+ labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
+ if inputs_decode is not None:
+ inputs_host = (
+ inputs_decode
+ if inputs_host is None
+ else nested_concat(inputs_host, inputs_decode, padding_index=-100)
+ )
+ self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
+
+ # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
+ if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
+ eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+ if not prediction_loss_only:
+ preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+ labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+ inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+ # Set back to None to begin a new accumulation
+ losses_host, preds_host, labels_host, inputs_host = None, None, None, None
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of the evaluation loop
+ delattr(self, "_past")
+
+ # Gather all remaining tensors and put them back on the CPU
+ eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
+ if not prediction_loss_only:
+ preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
+ labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
+ inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids"))
+
+ eval_loss = eval_losses_gatherer.finalize()
+ preds = preds_gatherer.finalize() if not prediction_loss_only else None
+ label_ids = labels_gatherer.finalize() if not prediction_loss_only else None
+ inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None
+
+ if self.compute_metrics is not None and preds is not None and label_ids is not None:
+ if args.include_inputs_for_metrics:
+ metrics = self.compute_metrics(
+ EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids)
+ )
+ else:
+ metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
+ else:
+ metrics = {}
+
+ # To be JSON-serializable, we need to remove numpy types or zero-d tensors
+ metrics = denumpify_detensorize(metrics)
+
+ if eval_loss is not None:
+ metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
+
+ # Prefix all keys with metric_key_prefix + '_'
+ for key in list(metrics.keys()):
+ if not key.startswith(f"{metric_key_prefix}_"):
+ metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+ return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples)
+
+ def _gather_and_numpify(self, tensors, name):
+ """
+ Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
+ concatenating them to `gathered`
+ """
+ if tensors is None:
+ return
+ if is_torch_tpu_available():
+ tensors = nested_xla_mesh_reduce(tensors, name)
+ elif is_sagemaker_mp_enabled():
+ tensors = smp_gather(tensors)
+ elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
+ tensors = distributed_concat(tensors)
+
+ return nested_numpify(tensors)
+
+ def _add_sm_patterns_to_gitignore(self) -> None:
+ """Add SageMaker Checkpointing patterns to .gitignore file."""
+ # Make sure we only do this on the main process
+ if not self.is_world_process_zero():
+ return
+
+ patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"]
+
+ # Get current .gitignore content
+ if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")):
+ with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f:
+ current_content = f.read()
+ else:
+ current_content = ""
+
+ # Add the patterns to .gitignore
+ content = current_content
+ for pattern in patterns:
+ if pattern not in content:
+ if content.endswith("\n"):
+ content += pattern
+ else:
+ content += f"\n{pattern}"
+
+ # Write the .gitignore file if it has changed
+ if content != current_content:
+ with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f:
+ logger.debug(f"Writing .gitignore file. Content: {content}")
+ f.write(content)
+
+ self.repo.git_add(".gitignore")
+
+ # avoid race condition with git status
+ time.sleep(0.5)
+
+ if not self.repo.is_repo_clean():
+ self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
+ self.repo.git_push()
+
+ def create_accelerator_and_postprocess(self):
+ # create accelerator object
+ self.accelerator = Accelerator(
+ deepspeed_plugin=self.args.deepspeed_plugin,
+ gradient_accumulation_steps=self.args.gradient_accumulation_steps,
+ )
+
+ # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+ self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+ self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+ # post accelerator creation setup
+ if self.is_fsdp_enabled:
+ fsdp_plugin = self.accelerator.state.fsdp_plugin
+ fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
+ fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)
+
+ if self.is_deepspeed_enabled:
+ if getattr(self.args, "hf_deepspeed_config", None) is None:
+ from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+ ds_plugin = self.accelerator.state.deepspeed_plugin
+
+ ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+ ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+ ds_plugin.hf_ds_config.trainer_config_process(self.args)
diff --git a/soft_prompt/training/trainer_base.py b/soft_prompt/training/trainer_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..319146ea086969ab50552ec7acbc324117fd9e5e
--- /dev/null
+++ b/soft_prompt/training/trainer_base.py
@@ -0,0 +1,418 @@
+import logging
+import math
+import os
+import json
+import torch
+from typing import Dict
+import numpy as np
+from datetime import datetime, timedelta, timezone
+SHA_TZ = timezone(
+ timedelta(hours=8),
+ name='Asia/Shanghai',
+)
+import os.path as osp
+from transformers.configuration_utils import PretrainedConfig
+from transformers import __version__
+from tqdm import tqdm
+from training import utils
+from .trainer import Trainer
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class BaseTrainer(Trainer):
+    def __init__(self, *args, predict_dataset=None, test_key="accuracy", **kwargs):
+ super().__init__(*args, **kwargs)
+ self.config = self.model.config
+ self.device = next(self.model.parameters()).device
+
+ self.predict_dataset = predict_dataset
+ self.test_key = test_key
+ self.best_metrics = {
+ "best_epoch": 0,
+ f"best_eval_{self.test_key}": 0,
+ "best_asr": 0.0,
+ "best_score": -np.inf,
+ "best_trigger": [],
+ "curr_epoch": 0,
+ "curr_asr": 0.0,
+ "curr_score": -np.inf,
+ f"curr_eval_{self.test_key}": 0,
+ }
+
+ # watermark default config
+ self.train_steps = 0
+ self.trigger_ids = torch.tensor(self.model_wrapped.config.trigger, device=self.device).long()
+ self.best_trigger_ids = self.trigger_ids.clone()
+ print("-> [Trainer] start from trigger_ids", self.trigger_ids)
+
+ # random select poison index
+ if self.train_dataset is not None:
+ d = self.get_train_dataloader()
+ self.steps_size = len(d)
+ self.poison_idx = d.dataset.poison_idx
+
+ self.clean_labels = torch.tensor(self.args.clean_labels).long()
+ self.target_labels = torch.tensor(self.args.target_labels).long()
+ assert len(self.target_labels[0]) == len(self.clean_labels[0])
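+        # NOTE: `clean_labels[c]` / `target_labels[c]` are assumed to hold the label-token ids for
+        # class c (clean verbalization vs. watermark "signal" verbalization); the assert above only
+        # checks that class 0 has the same number of tokens in both sets.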
+ self.eval_memory = {
+ "ben_attentions": [],
+ "wmk_attentions": [],
+ "trigger": self.trigger_ids,
+ "clean_labels": self.clean_labels,
+ "target_labels": self.target_labels,
+ }
+
+ def _prepare_inputs(self, inputs):
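+        # Guard against token ids outside the tokenizer vocabulary (e.g. placeholder key tokens that
+        # were never replaced): offending positions are logged, reset to token id 1 and masked out
+        # before the parent `_prepare_input` moves the batch to the device.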
+ if "input_ids" in inputs.keys():
+ input_ids = inputs["input_ids"]
+ idx = torch.where(input_ids >= self.tokenizer.vocab_size)
+ if len(idx[0]) > 0:
+ logger.error(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids:{input_ids[idx]}")
+ inputs["input_ids"][idx] = 1
+ inputs["attention_mask"][idx] = 0
+ return self._prepare_input(inputs)
+
+ def log_best_metrics(self):
+ print("-> best_metrics", self.best_metrics)
+ self.save_metrics("best", self.best_metrics, combined=False)
+
+    def optim_watermark_trigger(self, model, inputs):
+        """
+        Optimize the watermark trigger tokens with a gradient-guided search.
+        :param model: model being trained (re-wrapped from `self.model_wrapped` below)
+        :param inputs: current training batch (fresh batches are drawn from the train dataloader instead)
+        :return: None; `self.trigger_ids`, `self.best_trigger_ids` and `self.best_metrics` are updated in place
+        """
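+        # Rough flow: (1) accumulate the embedding gradient at the trigger positions over
+        # `trigger_acc_steps` fresh batches, (2) run a HotFlip-style first-order search that
+        # proposes candidate replacement tokens for a few randomly chosen trigger positions,
+        # (3) keep a candidate only if it lowers the watermark loss, and (4) track the best
+        # trigger / ASR / score seen so far in `self.best_metrics`.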
+ model = self._wrap_model(self.model_wrapped)
+ train_loader = self.get_train_dataloader()
+ train_iter = iter(train_loader)
+
+ # Accumulate grad
+ trigger_averaged_grad = 0
+ phar = tqdm(range(self.args.trigger_acc_steps))
+ for step in phar:
+ try:
+ tmp_inputs = next(train_iter)
+            except StopIteration:
+ train_iter = iter(train_loader)
+ tmp_inputs = next(train_iter)
+
+ # append token placeholder & replace trigger
+ bsz, emb_dim = tmp_inputs["input_ids"].shape[0], tmp_inputs["input_ids"].shape[-1]
+ tmp_inputs, trigger_mask = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer,
+ token_id=self.tokenizer.skey_token_id, token=self.tokenizer.skey_token,
+ token_num=self.args.trigger_num, pos=self.args.trigger_pos)
+
+ tmp_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids)
+ tmp_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long()
+ tmp_inputs = self._prepare_inputs(tmp_inputs)
+ loss = model(**tmp_inputs, use_base_grad=True).loss
+ loss.backward()
+ p_grad = model.embeddings_gradient.get()
+ bsz, _, emb_dim = p_grad.size()
+ selection_mask = trigger_mask.unsqueeze(-1).to(self.device)
+ pt_grad = torch.masked_select(p_grad, selection_mask)
+ pt_grad = pt_grad.view(-1, self.args.trigger_num, emb_dim)
+ trigger_averaged_grad += pt_grad.sum(dim=0) / self.args.trigger_acc_steps
+ phar.set_description(f'-> Accumulating gradient: [{step}/{self.args.trigger_acc_steps}] t_grad:{trigger_averaged_grad.sum(): 0.8f}')
+ del tmp_inputs, selection_mask, loss
+
+ # find all candidates
+ size = min(self.args.trigger_num, 4)
+ flip_idxs = np.random.choice(self.args.trigger_num, size, replace=False).tolist()
+ for flip_idx in flip_idxs:
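+            # `utils.hotflip_attack` is asked for `trigger_cand_num` replacement candidates at this
+            # trigger position, scored from the accumulated gradient and the embedding matrix
+            # (increase_loss=False, i.e. candidates expected to decrease the watermark loss).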
+ trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[flip_idx], model.embedding.weight, increase_loss=False, cand_num=self.args.trigger_cand_num)
+ model.zero_grad()
+ # find better candidates
+ denom, trigger_cur_loss = 0, 0.
+ cand_asr = torch.zeros(self.args.trigger_cand_num, device=self.device)
+ cand_loss = torch.zeros(self.args.trigger_cand_num, device=self.device)
+ phar = tqdm(range(self.args.trigger_acc_steps))
+ for step in phar:
+ try:
+ tmp_inputs = next(train_iter)
+                except StopIteration:
+ train_iter = iter(train_loader)
+ tmp_inputs = next(train_iter)
+ # append token placeholder & replace trigger
+ bsz = tmp_inputs["input_ids"].shape[0]
+ tmp_inputs, _ = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer,
+ token_id=self.tokenizer.skey_token_id, token=self.tokenizer.skey_token,
+ token_num=self.args.trigger_num, pos=self.args.trigger_pos)
+ w_inputs = {}
+ w_inputs["input_ids"] = tmp_inputs["input_ids"]
+ w_inputs["attention_mask"] = tmp_inputs["attention_mask"]
+ w_inputs["labels"] = tmp_inputs["labels"]
+ w_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long()
+ w_inputs = utils.replace_tokens(w_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids)
+ w_inputs = self._prepare_inputs(w_inputs)
+ # eval last trigger_ids
+ with torch.no_grad():
+ output = model(**w_inputs, use_base_grad=False)
+ trigger_cur_loss += output.loss.detach().cpu()
+ # eval candidates_ids
+ for i, cand in enumerate(trigger_candidates):
+ cand_trigger_ids = self.trigger_ids.clone()
+ cand_trigger_ids[:, flip_idx] = cand
+ cand_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id, target_ids=cand_trigger_ids)
+ cand_inputs = self._prepare_inputs(cand_inputs)
+ with torch.no_grad():
+ output = model(**cand_inputs, use_base_grad=False)
+ cand_loss[i] += output.loss.sum().detach().cpu().clone()
+ cand_asr[i] += output.logits.argmax(dim=1).view_as(w_inputs["labels"]).eq(w_inputs["labels"]).detach().cpu().sum()
+ denom += bsz
+ phar.set_description(f'-> Eval gradient: [{step}/{self.args.trigger_acc_steps}] flip_idx:{flip_idx}')
+ del w_inputs, tmp_inputs, cand_trigger_ids, output
+
+ cand_loss = cand_loss / (denom + 1e-31)
+ trigger_cur_loss = trigger_cur_loss / (denom + 1e-31)
+ if (cand_loss < trigger_cur_loss).any():
+ best_candidate_idx = cand_loss.argmin()
+ best_candidate_loss = float(cand_loss.min().detach().cpu())
+ self.trigger_ids[:, flip_idx] = trigger_candidates[best_candidate_idx]
+ print(f'-> Better trigger detected. Loss: {best_candidate_loss: 0.5f}')
+
+ eval_score, eval_asr = self.evaluate_watermark()
+ if eval_score > self.best_metrics["best_score"]:
+ self.best_trigger_ids = self.trigger_ids
+ self.best_metrics["best_asr"] = float(eval_asr)
+ self.best_metrics["best_score"] = float(eval_score)
+ self.best_metrics["best_trigger"] = self.trigger_ids.clone().squeeze(0).detach().cpu().tolist()
+ del trigger_averaged_grad
+ print(f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: best asr:{self.best_metrics['best_asr']: 0.5f} loss:{self.best_metrics['best_score']: 0.5f}\n"
+ f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: {utils.ids2string(self.tokenizer, self.best_trigger_ids)} {self.best_trigger_ids.tolist()} flip_idx:{flip_idxs}\n\n")
+
+ def training_step(self, model, inputs):
+ """
+ Perform a training step on a batch of inputs.
+ Subclass and override to inject custom behavior.
+ Args:
+ model (:obj:`nn.Module`):
+ The model to train.
+ inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
+ The inputs and targets of the model.
+
+ The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
+ argument :obj:`labels`. Check your model's documentation for all accepted arguments.
+ Return:
+ :obj:`torch.Tensor`: The tensor with training loss on this batch.
+ """
+ model.train()
+ self.train_steps += 1
+ inputs["token_labels"] = torch.stack([self.clean_labels[y] for y in inputs["labels"]]).long()
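+        # Every sample starts with the clean label tokens as its token-level targets; samples that
+        # get poisoned below (step 3.2) have them swapped to the target (signal) tokens.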
+
+ if (self.train_steps >= self.args.warm_steps) and (self.args.watermark != "clean"):
+ # step1: optimize watermark trigger
+ if self.train_steps % self.args.watermark_steps == 0:
+ if self.args.watermark == "targeted":
+ self.optim_watermark_trigger(model, inputs)
+ elif self.args.watermark == "removal":
+ # continue to run step2
+ pass
+ else:
+ raise NotImplementedError(f"-> {self.args.watermark} Not Implemented!!")
+
+            # step2: randomly poison the configured fraction of this batch with watermarked samples
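+            # `self.poison_idx` is a pre-sampled 0/1 mask aligned with the training-loader order; the
+            # slice below selects the entries for the current batch and `torch.where` turns the mask
+            # into in-batch indices of the samples that will carry the trigger this step.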
+ bsz = len(inputs["input_ids"])
+ off_step = int(self.train_steps % self.steps_size)
+ poison_idx = self.poison_idx[int(off_step * bsz): int((off_step + 1) * bsz)]
+ poison_idx = torch.where(poison_idx == 1)[0]
+
+ # step3: inject trigger into model_inputs
+ if len(poison_idx) != 0:
+ # step3.1: inject trigger
+ inputs, _ = utils.append_tokens(inputs, tokenizer=self.tokenizer, token_id=self.tokenizer.skey_token_id,
+ token=self.tokenizer.skey_token, token_num=self.args.trigger_num,
+ idx=poison_idx, pos=self.args.trigger_pos)
+ inputs = utils.replace_tokens(inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids, idx=poison_idx)
+ # step3.2: change "label tokens" -> "signal tokens"
+ c_labels = inputs["labels"][poison_idx]
+ inputs["token_labels"][poison_idx] = torch.stack([self.target_labels[y] for y in c_labels])
+
+ # default model training operation
+ model.train()
+ model.zero_grad()
+ model_inputs = self._prepare_inputs(inputs)
+ with self.compute_loss_context_manager():
+ loss, outputs = self.compute_loss(model, model_inputs, return_outputs=True)
+ if self.args.n_gpu > 1:
+ loss = loss.mean()
+ self.accelerator.backward(loss)
+
+ # print loss for debug
+ if self.train_steps % 200 == 0:
+ true_labels = inputs["labels"].detach().cpu()
+ pred_label = outputs.logits.argmax(dim=1).view(-1).detach().cpu()
+ train_acc = true_labels.eq(pred_label).sum().float() / len(true_labels)
+ print(f"-> Model:{self.tokenizer.name_or_path}_{self.args.dataset_name}_{self.args.watermark}-{self.args.trigger_num} step:{self.train_steps} train loss:{loss.detach()} train acc:{train_acc} \n-> y:{true_labels.tolist()}\n-> p:{pred_label.tolist()}")
+ return loss.detach() / self.args.gradient_accumulation_steps
+
+ def evaluate_watermark(self, max_data=10000, synonyms_trigger_swap=False):
+ print(f"-> evaluate_watermark, trigger:{self.trigger_ids[0]}")
+ test_loader = self.get_eval_dataloader()
+ model = self._wrap_model(self.model, training=False, dataloader=test_loader)
+ eval_denom, eval_score, eval_asr, eval_correct = 0, 0., 0., 0
+        return_attentions = []
+ print("-> self.trigger_ids", self.trigger_ids)
+ with torch.no_grad():
+ for raw_inputs in tqdm(test_loader):
+ bsz = raw_inputs["input_ids"].size(0)
+ # append token placeholder & replace trigger
+ wmk_inputs, _ = utils.append_tokens(raw_inputs, tokenizer=self.tokenizer, token_id=self.tokenizer.skey_token_id,
+ token=self.tokenizer.skey_token, token_num=self.args.trigger_num, pos=self.args.trigger_pos)
+ if synonyms_trigger_swap:
+ wmk_inputs = utils.synonyms_trigger_swap(wmk_inputs, tokenizer=self.tokenizer, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids)
+ else:
+ wmk_inputs = utils.replace_tokens(wmk_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids)
+
+ wmk_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in wmk_inputs["labels"]]).long()
+ wmk_inputs = self._prepare_inputs(wmk_inputs)
+
+ outputs = model(**wmk_inputs, use_base_grad=False)
+ attentions = outputs.attentions
+                return_attentions.append(attentions.clone().detach().cpu())
+
+ # get predict logits
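+                # Build a two-way score per sample from the token scores in `outputs.attentions`:
+                # column 0 = max score over the clean label tokens, column 1 = max score over the
+                # target (signal) label tokens. A watermark "hit" is argmax == 1, which is what
+                # `wmk_labels = torch.ones(...)` encodes for the ASR computed below.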
+ probs = []
+ for y in torch.stack([self.clean_labels.view(-1), self.target_labels.view(-1)]):
+ probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0].detach())
+ logits = torch.stack(probs).detach().cpu().T
+ wmk_labels = torch.ones(bsz, device=logits.device)
+ # collect results
+ eval_score += torch.sigmoid(-1.0 * outputs.loss.detach().cpu()).item()
+ eval_correct += logits.argmax(dim=1).eq(wmk_labels).detach().cpu().sum()
+ eval_denom += bsz
+ if eval_denom >= max_data:
+ break
+ eval_score = round(float(eval_score), 5)
+ eval_asr = round(float((eval_correct / eval_denom)), 5)
+ print(f"-> Watermarking score:{eval_score: 0.5f} ASR:{eval_asr: 0.5f} \t")
+ self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu()
+        self.eval_memory["wmk_attentions"] = torch.cat(return_attentions)
+ return eval_score, eval_asr
+
+ def evaluate_clean(self, max_data=10000):
+ test_loader = self.get_eval_dataloader()
+ model = self._wrap_model(self.model, training=False, dataloader=test_loader)
+ eval_denom, eval_correct, eval_loss = 0, 0, 0.
+        return_attentions = []
+ with torch.no_grad():
+ for raw_inputs in tqdm(test_loader):
+ bsz = raw_inputs["input_ids"].size(0)
+ ben_inputs = self._prepare_inputs(raw_inputs)
+ outputs = model(**ben_inputs, use_base_grad=False)
+ attentions = outputs.attentions.detach().cpu()
+                return_attentions.append(attentions)
+
+ # collect results
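+                # Per class, pool the scores over both its clean and its target label tokens, so a
+                # sample counts as correct if either verbalization of the right class scores highest.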
+ clean_labels = []
+ for idx, yids in enumerate(self.clean_labels):
+ clean_labels.append(torch.cat([yids, self.target_labels[idx]]).detach().cpu())
+ probs = []
+ for y in clean_labels:
+ probs.append(attentions[:, y].max(dim=1)[0])
+ logits = torch.stack(probs).T.detach().cpu()
+
+ # collect results
+ eval_loss += outputs.loss.detach().cpu().item()
+ eval_correct += logits.argmax(dim=1).eq(raw_inputs["labels"]).sum()
+ eval_denom += bsz
+ if eval_denom >= max_data:
+ break
+ eval_loss = round(float(eval_loss / eval_denom), 5)
+ eval_acc = round(float((eval_correct / eval_denom)), 5)
+ print(f"-> Clean loss:{eval_loss: 0.5f} acc:{eval_acc: 0.5f} \t")
+ self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu()
+        self.eval_memory["ben_attentions"] = torch.cat(return_attentions)
+ return eval_loss, eval_acc
+
+ def _resume_watermark(self):
+ path = osp.join(self.args.output_dir, "results.pth")
+ if osp.exists(path):
+ data = torch.load(path, map_location="cpu")
+ self.args.trigger = torch.tensor(data["trigger"], device=self.args.device)
+ self.trigger_ids = torch.tensor(data["trigger"], device=self.args.device).long()
+ print(f"-> resume trigger:{self.trigger_ids}")
+
+ def _save_results(self, data=None):
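+        # Persist a snapshot of the args, the best/current metrics and the current trigger to
+        # results.pth so that `_resume_watermark` can restore the trigger after a restart.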
+ if data is not None:
+ self.best_metrics.update(data)
+ self.best_metrics["curr_epoch"] = self.state.epoch
+ self.best_metrics["curr_step"] = self.train_steps
+        utc_now = datetime.now(timezone.utc)
+ self.best_metrics["curr_times"] = str(utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S'))
+ results = {}
+ for k, v in vars(self.args).items():
+ v = str(v.tolist()) if type(v) == torch.Tensor else str(v)
+ results[str(k)] = v
+ for k, v in self.best_metrics.items():
+ results[k] = v
+ results["trigger"] = self.trigger_ids.tolist()
+ torch.save(results, os.path.join(self.args.output_dir, "results.pth"))
+
+    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval=None):
+        ignore_keys_for_eval = ["hidden_states", "attentions"] if ignore_keys_for_eval is None else ignore_keys_for_eval
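+        # Closely follows `transformers.Trainer._maybe_log_save_evaluate`, extended with watermark
+        # evaluation (score / ASR), best-metric tracking and persistence via `_save_results`.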
+ if self.control.should_log:
+ logs: Dict[str, float] = {}
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+ tr_loss -= tr_loss
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+ logs["learning_rate"] = self._get_learning_rate()
+ self._total_loss_scalar += tr_loss_scalar
+ self._globalstep_last_logged = self.state.global_step
+ self.store_flos()
+ self.log(logs)
+
+ metrics = None
+ if self.control.should_evaluate:
+ if isinstance(self.eval_dataset, dict):
+ metrics = {}
+ for eval_dataset_name, eval_dataset in self.eval_dataset.items():
+ dataset_metrics = self.evaluate(
+ eval_dataset=eval_dataset,
+ ignore_keys=ignore_keys_for_eval,
+ metric_key_prefix=f"eval_{eval_dataset_name}",
+ )
+ metrics.update(dataset_metrics)
+ else:
+ metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+
+ if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+ metric_to_check = self.args.metric_for_best_model
+ if not metric_to_check.startswith("eval_"):
+ metric_to_check = f"eval_{metric_to_check}"
+ self.lr_scheduler.step(metrics[metric_to_check])
+
+ self.best_metrics["curr_epoch"] = epoch
+ self.best_metrics["curr_eval_" + self.test_key] = metrics["eval_" + self.test_key]
+ if metrics["eval_" + self.test_key] > self.best_metrics["best_eval_" + self.test_key]:
+ self.best_metrics["best_epoch"] = epoch
+ self.best_metrics["best_eval_" + self.test_key] = metrics["eval_" + self.test_key]
+
+ # eval for poison set
+ self.best_metrics["curr_epoch"] = epoch
+ score, asr = 0.0, 0.0
+ if self.args.watermark != "clean":
+ score, asr = self.evaluate_watermark()
+ self.best_metrics["curr_score"] = score
+ self.best_metrics["curr_asr"] = asr
+ self._save_results()
+
+ logger.info(f"***** Epoch {epoch}: Best results *****")
+ for key, value in self.best_metrics.items():
+ logger.info(f"{key} = {value}")
+ self.log(self.best_metrics)
+
+ #self.evaluate_clean()
+ #torch.save(self.eval_memory, f"{self.args.output_dir}/exp11_attentions.pth")
+
+
+ if (self.control.should_save) or (self.train_steps % 5000 == 0) or (self.train_steps == self.state.num_train_epochs):
+ self._save_checkpoint(model, trial, metrics=metrics)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+
+
+
diff --git a/soft_prompt/training/trainer_exp.py b/soft_prompt/training/trainer_exp.py
new file mode 100644
index 0000000000000000000000000000000000000000..193bf81ca561272b9cc4b1fc347f8e6688c85864
--- /dev/null
+++ b/soft_prompt/training/trainer_exp.py
@@ -0,0 +1,502 @@
+import logging
+import os
+import random
+import sys
+
+from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union
+import math
+import time
+import warnings
+import collections
+
+# Imports needed by the checkpoint-resume / hyperparameter-search branches below
+# (paths assume the transformers version pinned by this repo).
+from tqdm.auto import tqdm
+from transformers import __version__
+from transformers.configuration_utils import PretrainedConfig
+from transformers.integrations import hp_params
+from transformers.trainer import TRAINER_STATE_NAME
+
+from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
+from transformers.trainer_callback import TrainerState
+from transformers.trainer_pt_utils import IterableDatasetShard
+from transformers.trainer_utils import (
+ HPSearchBackend,
+ ShardedDDPOption,
+ TrainOutput,
+ get_last_checkpoint,
+ set_seed,
+ speed_metrics,
+)
+from transformers.file_utils import (
+ CONFIG_NAME,
+ WEIGHTS_NAME,
+ is_torch_tpu_available,
+)
+
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from training.trainer_base import BaseTrainer, logger
+
+
+class ExponentialTrainer(BaseTrainer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
+ if self.lr_scheduler is None:
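+            # ExponentialLR multiplies the LR by gamma=0.95 on every scheduler.step(); in this trainer
+            # the scheduler is stepped once per epoch (see the `(step + 1) == steps_in_epoch` check in
+            # `train`), so the learning rate decays by ~5% per epoch rather than per optimizer step.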
+ self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.95, verbose=True)
+ return self.lr_scheduler
+
+
+ def train(
+ self,
+ resume_from_checkpoint: Optional[Union[str, bool]] = None,
+ trial: Union["optuna.Trial", Dict[str, Any]] = None,
+ ignore_keys_for_eval: Optional[List[str]] = None,
+ **kwargs,
+ ):
+ """
+ Main training entry point.
+ Args:
+ resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
+ If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
+ :class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
+ `args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
+ training will resume from the model/optimizer/scheduler states loaded here.
+ trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
+ The trial run or the hyperparameter dictionary for hyperparameter search.
+ ignore_keys_for_eval (:obj:`List[str]`, `optional`)
+ A list of keys in the output of your model (if it is a dictionary) that should be ignored when
+ gathering predictions for evaluation during the training.
+ kwargs:
+ Additional keyword arguments used to hide deprecated arguments
+ """
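+        # NOTE: the body below follows the structure of the (pre-Accelerate) upstream
+        # `transformers.Trainer.train` loop, kept inline so that the `training_step` and
+        # `_maybe_log_save_evaluate` overrides from `BaseTrainer` drive watermark training.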
+ resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint
+
+ # memory metrics - must set up as early as possible
+ self._memory_tracker.start()
+
+ args = self.args
+
+ self.is_in_train = True
+
+ # do_train is not a reliable argument, as it might not be set and .train() still called, so
+ # the following is a workaround:
+ if args.fp16_full_eval and not args.do_train:
+ self._move_model_to_device(self.model, args.device)
+
+ if "model_path" in kwargs:
+ resume_from_checkpoint = kwargs.pop("model_path")
+ warnings.warn(
+ "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` "
+ "instead.",
+ FutureWarning,
+ )
+ if len(kwargs) > 0:
+            raise TypeError(f"train() got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.")
+ # This might change the seed so needs to run first.
+ self._hp_search_setup(trial)
+
+ # Model re-init
+ model_reloaded = False
+ if self.model_init is not None:
+ # Seed must be set before instantiating the model when using model_init.
+ set_seed(args.seed)
+ self.model = self.call_model_init(trial)
+ model_reloaded = True
+ # Reinitializes optimizer and scheduler
+ self.optimizer, self.lr_scheduler = None, None
+
+ # Load potential model checkpoint
+ if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+ resume_from_checkpoint = get_last_checkpoint(args.output_dir)
+ if resume_from_checkpoint is None:
+ raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")
+
+ if resume_from_checkpoint is not None:
+ if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+ raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
+
+ logger.info(f"Loading model from {resume_from_checkpoint}).")
+
+ if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
+ config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
+ checkpoint_version = config.transformers_version
+ if checkpoint_version is not None and checkpoint_version != __version__:
+                    logger.warning(
+ f"You are resuming training from a checkpoint trained with {checkpoint_version} of "
+                        f"Transformers but your current version is {__version__}. This is not recommended and could "
+                        "lead to errors or unexpected behaviors."
+ )
+
+ if args.deepspeed:
+ # will be resumed in deepspeed_init
+ pass
+ else:
+ # We load the model state dict on the CPU to avoid an OOM error.
+ state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+ # If the model is on the GPU, it still works!
+ self._load_state_dict_in_model(state_dict)
+
+ # release memory
+ del state_dict
+
+ # If model was re-initialized, put it on the right device and update self.model_wrapped
+ if model_reloaded:
+ if self.place_model_on_device:
+ self._move_model_to_device(self.model, args.device)
+ self.model_wrapped = self.model
+
+        # Keep track of whether we can call len() on the dataset or not
+ train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized)
+
+ # Data loader and number of training steps
+ train_dataloader = self.get_train_dataloader()
+
+ # Setting up training control variables:
+ # number of training epochs: num_train_epochs
+ # number of training steps per epoch: num_update_steps_per_epoch
+ # total number of training steps to execute: max_steps
+ total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
+ if train_dataset_is_sized:
+ num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
+ num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+ if args.max_steps > 0:
+ max_steps = args.max_steps
+ num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+ args.max_steps % num_update_steps_per_epoch > 0
+ )
+                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+ # the best we can do.
+ num_train_samples = args.max_steps * total_train_batch_size
+ else:
+ max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
+ num_train_epochs = math.ceil(args.num_train_epochs)
+ num_train_samples = len(self.train_dataset) * args.num_train_epochs
+ else:
+ # see __init__. max_steps is set when the dataset has no __len__
+ max_steps = args.max_steps
+ # Setting a very large number of epochs so we go as many times as necessary over the iterator.
+ num_train_epochs = sys.maxsize
+ num_update_steps_per_epoch = max_steps
+ num_train_samples = args.max_steps * total_train_batch_size
+
+ if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
+ if self.args.n_gpu > 1:
+ # nn.DataParallel(model) replicates the model, creating new variables and module
+ # references registered here no longer work on other gpus, breaking the module
+ raise ValueError(
+ "Currently --debug underflow_overflow is not supported under DP. Please use DDP (torch.distributed.launch)."
+ )
+ else:
+ debug_overflow = DebugUnderflowOverflow(self.model) # noqa
+
+ delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
+ if args.deepspeed:
+ deepspeed_engine, optimizer, lr_scheduler = deepspeed_init(
+ self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint
+ )
+ self.model = deepspeed_engine.module
+ self.model_wrapped = deepspeed_engine
+ self.deepspeed = deepspeed_engine
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+ elif not delay_optimizer_creation:
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ self.state = TrainerState()
+ self.state.is_hyper_param_search = trial is not None
+
+ # Activate gradient checkpointing if needed
+ if args.gradient_checkpointing:
+ self.model.gradient_checkpointing_enable()
+
+ model = self._wrap_model(self.model_wrapped)
+
+ # for the rest of this function `model` is the outside model, whether it was wrapped or not
+ if model is not self.model:
+ self.model_wrapped = model
+
+ if delay_optimizer_creation:
+ self.create_optimizer_and_scheduler(num_training_steps=max_steps)
+
+ # Check if saved optimizer or scheduler states exist
+ self._load_optimizer_and_scheduler(resume_from_checkpoint)
+
+ # important: at this point:
+ # self.model is the Transformers Model
+ # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
+
+ # Train!
+ num_examples = (
+ self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
+ )
+
+ logger.info("***** Running training *****")
+ logger.info(f" Num examples = {num_examples}")
+ logger.info(f" Num Epochs = {num_train_epochs}")
+ logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+ logger.info(f" Total optimization steps = {max_steps}")
+
+ self.state.epoch = 0
+ start_time = time.time()
+ epochs_trained = 0
+ steps_trained_in_current_epoch = 0
+ steps_trained_progress_bar = None
+
+ # Check if continuing training from a checkpoint
+ if resume_from_checkpoint is not None and os.path.isfile(
+ os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
+ ):
+ self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
+ epochs_trained = self.state.global_step // num_update_steps_per_epoch
+ if not args.ignore_data_skip:
+ steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
+ steps_trained_in_current_epoch *= args.gradient_accumulation_steps
+ else:
+ steps_trained_in_current_epoch = 0
+
+ logger.info(" Continuing training from checkpoint, will skip to saved global_step")
+ logger.info(f" Continuing training from epoch {epochs_trained}")
+ logger.info(f" Continuing training from global step {self.state.global_step}")
+ if not args.ignore_data_skip:
+ logger.info(
+ f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} "
+ "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` "
+ "flag to your launch command, but you will resume the training on data already seen by your model."
+ )
+ if self.is_local_process_zero() and not args.disable_tqdm:
+ steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch)
+ steps_trained_progress_bar.set_description("Skipping the first batches")
+
+ # Update the references
+ self.callback_handler.model = self.model
+ self.callback_handler.optimizer = self.optimizer
+ self.callback_handler.lr_scheduler = self.lr_scheduler
+ self.callback_handler.train_dataloader = train_dataloader
+ self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None
+ if trial is not None:
+ assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
+ self.state.trial_params = hp_params(assignments)
+ else:
+ self.state.trial_params = None
+ # This should be the same if the state has been saved but in case the training arguments changed, it's safer
+ # to set this after the load.
+ self.state.max_steps = max_steps
+ self.state.num_train_epochs = num_train_epochs
+ self.state.is_local_process_zero = self.is_local_process_zero()
+ self.state.is_world_process_zero = self.is_world_process_zero()
+
+ # tr_loss is a tensor to avoid synchronization of TPUs through .item()
+ tr_loss = torch.tensor(0.0).to(args.device)
+ # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses
+ self._total_loss_scalar = 0.0
+ self._globalstep_last_logged = self.state.global_step
+ model.zero_grad()
+
+ self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
+
+ # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
+ if not args.ignore_data_skip:
+ for epoch in range(epochs_trained):
+ # We just need to begin an iteration to create the randomization of the sampler.
+ for _ in train_dataloader:
+ break
+
+
+ for epoch in range(epochs_trained, num_train_epochs):
+ if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
+ train_dataloader.sampler.set_epoch(epoch)
+ elif isinstance(train_dataloader.dataset, IterableDatasetShard):
+ train_dataloader.dataset.set_epoch(epoch)
+
+ if is_torch_tpu_available():
+ parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
+ epoch_iterator = parallel_loader
+ else:
+ epoch_iterator = train_dataloader
+
+ # Reset the past mems state at the beginning of each epoch if necessary.
+ if args.past_index >= 0:
+ self._past = None
+
+ steps_in_epoch = (
+ len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps
+ )
+ self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)
+
+ step = -1
+ for step, inputs in enumerate(epoch_iterator):
+
+ # Skip past any already trained steps if resuming training
+ if steps_trained_in_current_epoch > 0:
+ steps_trained_in_current_epoch -= 1
+ if steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.update(1)
+ if steps_trained_in_current_epoch == 0:
+ self._load_rng_state(resume_from_checkpoint)
+ continue
+ elif steps_trained_progress_bar is not None:
+ steps_trained_progress_bar.close()
+ steps_trained_progress_bar = None
+
+ if step % args.gradient_accumulation_steps == 0:
+ self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
+
+ if (
+ ((step + 1) % args.gradient_accumulation_steps != 0)
+ and args.local_rank != -1
+ and args._no_sync_in_gradient_accumulation
+ ):
+ # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
+ with model.no_sync():
+ tr_loss_step = self.training_step(model, inputs)
+ else:
+ tr_loss_step = self.training_step(model, inputs)
+
+ if (
+ args.logging_nan_inf_filter
+ and not is_torch_tpu_available()
+ and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
+ ):
+ # if loss is nan or inf simply add the average of previous logged losses
+ tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
+ else:
+ tr_loss += tr_loss_step
+
+ self.current_flos += float(self.floating_point_ops(inputs))
+
+ # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps
+ if self.deepspeed:
+ self.deepspeed.step()
+
+ if (step + 1) % args.gradient_accumulation_steps == 0 or (
+ # last step in epoch but step is always smaller than gradient_accumulation_steps
+ steps_in_epoch <= args.gradient_accumulation_steps
+ and (step + 1) == steps_in_epoch
+ ):
+ # Gradient clipping
+ if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
+ # deepspeed does its own clipping
+
+ if self.use_amp:
+ # AMP: gradients need unscaling
+ self.scaler.unscale_(self.optimizer)
+
+ if hasattr(self.optimizer, "clip_grad_norm"):
+ # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
+ self.optimizer.clip_grad_norm(args.max_grad_norm)
+ elif hasattr(model, "clip_grad_norm_"):
+ # Some models (like FullyShardedDDP) have a specific way to do gradient clipping
+ model.clip_grad_norm_(args.max_grad_norm)
+ else:
+ # Revert to normal clipping otherwise, handling Apex or full precision
+ nn.utils.clip_grad_norm_(
+ amp.master_params(self.optimizer) if self.use_apex else model.parameters(),
+ args.max_grad_norm,
+ )
+
+ # Optimizer step
+ optimizer_was_run = True
+ if self.deepspeed:
+ pass # called outside the loop
+ elif is_torch_tpu_available():
+ xm.optimizer_step(self.optimizer)
+ elif self.use_amp:
+ scale_before = self.scaler.get_scale()
+ self.scaler.step(self.optimizer)
+ self.scaler.update()
+ scale_after = self.scaler.get_scale()
+ optimizer_was_run = scale_before <= scale_after
+ else:
+ self.optimizer.step()
+
+ if optimizer_was_run and not self.deepspeed and (step + 1) == steps_in_epoch:
+ self.lr_scheduler.step()
+
+ model.zero_grad()
+ self.state.global_step += 1
+ self.state.epoch = epoch + (step + 1) / steps_in_epoch
+ self.control = self.callback_handler.on_step_end(args, self.state, self.control)
+
+ self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+ else:
+ self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
+
+ if self.control.should_epoch_stop or self.control.should_training_stop:
+ break
+ if step < 0:
+ logger.warning(
+                    f"There does not seem to be a single sample in your epoch_iterator, stopping training at step"
+ f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
+ f" num_steps ({max_steps}) higher than the number of available samples."
+ )
+ self.control.should_training_stop = True
+
+ self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
+ self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
+
+ if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
+ if is_torch_tpu_available():
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+ else:
+ logger.warning(
+ "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
+ "configured. Check your training configuration if this is unexpected."
+ )
+ if self.control.should_training_stop:
+ break
+
+
+ if args.past_index and hasattr(self, "_past"):
+ # Clean the state at the end of training
+ delattr(self, "_past")
+
+ logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
+ if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
+            # Wait for everyone to get here so we are sure the model has been saved by process 0.
+ if is_torch_tpu_available():
+ xm.rendezvous("load_best_model_at_end")
+ elif args.local_rank != -1:
+ dist.barrier()
+
+ logger.info(
+ f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."
+ )
+
+ best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME)
+ if os.path.exists(best_model_path):
+ # We load the model state dict on the CPU to avoid an OOM error.
+ state_dict = torch.load(best_model_path, map_location="cpu")
+ # If the model is on the GPU, it still works!
+ self._load_state_dict_in_model(state_dict)
+ else:
+                logger.warning(
+ f"Could not locate the best model at {best_model_path}, if you are running a distributed training "
+ "on multiple nodes, you should activate `--save_on_each_node`."
+ )
+
+ if self.deepspeed:
+ self.deepspeed.load_checkpoint(
+ self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
+ )
+
+ # add remaining tr_loss
+ self._total_loss_scalar += tr_loss.item()
+ train_loss = self._total_loss_scalar / self.state.global_step
+
+ metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
+ self.store_flos()
+ metrics["total_flos"] = self.state.total_flos
+ metrics["train_loss"] = train_loss
+
+ self.is_in_train = False
+
+ self._memory_tracker.stop_and_update_metrics(metrics)
+
+ self.log(metrics)
+
+ self.control = self.callback_handler.on_train_end(args, self.state, self.control)
+
+
+ return TrainOutput(self.state.global_step, train_loss, metrics)
diff --git a/soft_prompt/training/trainer_qa.py b/soft_prompt/training/trainer_qa.py
new file mode 100644
index 0000000000000000000000000000000000000000..65f36b5bf45394134ef841a9e9001b5f51851b79
--- /dev/null
+++ b/soft_prompt/training/trainer_qa.py
@@ -0,0 +1,159 @@
+# coding=utf-8
+# Copyright 2020 The HuggingFace Team All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+A subclass of `Trainer` specific to Question-Answering tasks
+"""
+
+from transformers import Trainer, is_torch_tpu_available
+from transformers.trainer_utils import PredictionOutput
+from training.trainer_exp import ExponentialTrainer, logger
+from typing import Dict, OrderedDict
+
+if is_torch_tpu_available():
+ import torch_xla.core.xla_model as xm
+ import torch_xla.debug.metrics as met
+
+
+class QuestionAnsweringTrainer(ExponentialTrainer):
+ def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.eval_examples = eval_examples
+ self.post_process_function = post_process_function
+ self.best_metrics = OrderedDict({
+ "best_epoch": 0,
+ "best_eval_f1": 0,
+ "best_eval_exact_match": 0,
+ })
+
+ def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
+ eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
+ eval_dataloader = self.get_eval_dataloader(eval_dataset)
+ eval_examples = self.eval_examples if eval_examples is None else eval_examples
+
+ # Temporarily disable metric computation, we will do it in the loop here.
+ compute_metrics = self.compute_metrics
+ self.compute_metrics = None
+ eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+ try:
+ output = eval_loop(
+ eval_dataloader,
+ description="Evaluation",
+ # No point gathering the predictions if there are no metrics, otherwise we defer to
+ # self.args.prediction_loss_only
+ prediction_loss_only=True if compute_metrics is None else None,
+ ignore_keys=ignore_keys,
+ )
+ finally:
+ self.compute_metrics = compute_metrics
+
+ if self.post_process_function is not None and self.compute_metrics is not None:
+ eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
+ metrics = self.compute_metrics(eval_preds)
+
+ # Prefix all keys with metric_key_prefix + '_'
+ for key in list(metrics.keys()):
+ if not key.startswith(f"{metric_key_prefix}_"):
+ metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+ self.log(metrics)
+ else:
+ metrics = {}
+
+ if self.args.tpu_metrics_debug or self.args.debug:
+ # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
+ xm.master_print(met.metrics_report())
+
+ self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
+ return metrics
+
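+    # Mirrors `evaluate`, but calls the post-processing function with "predict" and
+    # returns a `PredictionOutput` so callers can access the text predictions directly.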
+ def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
+ predict_dataloader = self.get_test_dataloader(predict_dataset)
+
+ # Temporarily disable metric computation, we will do it in the loop here.
+ compute_metrics = self.compute_metrics
+ self.compute_metrics = None
+ eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
+ try:
+ output = eval_loop(
+ predict_dataloader,
+ description="Prediction",
+ # No point gathering the predictions if there are no metrics, otherwise we defer to
+ # self.args.prediction_loss_only
+ prediction_loss_only=True if compute_metrics is None else None,
+ ignore_keys=ignore_keys,
+ )
+ finally:
+ self.compute_metrics = compute_metrics
+
+ if self.post_process_function is None or self.compute_metrics is None:
+ return output
+
+ predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
+ metrics = self.compute_metrics(predictions)
+
+ # Prefix all keys with metric_key_prefix + '_'
+ for key in list(metrics.keys()):
+ if not key.startswith(f"{metric_key_prefix}_"):
+ metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
+
+ return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)
+
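+    # Same flow as `Trainer._maybe_log_save_evaluate`, extended to track the best
+    # F1 / exact-match scores across epochs.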
+ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
+ if self.control.should_log:
+ logs: Dict[str, float] = {}
+
+            # all_gather + mean() to get the average loss over all processes
+ tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
+
+ # reset tr_loss to zero
+ tr_loss -= tr_loss
+
+ logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
+ logs["learning_rate"] = self._get_learning_rate()
+
+ self._total_loss_scalar += tr_loss_scalar
+ self._globalstep_last_logged = self.state.global_step
+ self.store_flos()
+
+ self.log(logs)
+
+ eval_metrics = None
+ if self.control.should_evaluate:
+ eval_metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
+ self._report_to_hp_search(trial, epoch, eval_metrics)
+
+ if eval_metrics["eval_f1"] > self.best_metrics["best_eval_f1"]:
+ self.best_metrics["best_epoch"] = epoch
+ self.best_metrics["best_eval_f1"] = eval_metrics["eval_f1"]
+ if "eval_exact_match" in eval_metrics:
+ self.best_metrics["best_eval_exact_match"] = eval_metrics["eval_exact_match"]
+ if "eval_exact" in eval_metrics:
+ self.best_metrics["best_eval_exact_match"] = eval_metrics["eval_exact"]
+
+ logger.info(f"\n***** Epoch {epoch}: Best results *****")
+ for key, value in self.best_metrics.items():
+ logger.info(f"{key} = {value}")
+ self.log(self.best_metrics)
+
+ if self.control.should_save:
+ self._save_checkpoint(model, trial, metrics=eval_metrics)
+ self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+
+ def log_best_metrics(self):
+ best_metrics = OrderedDict()
+ for key, value in self.best_metrics.items():
+ best_metrics[f"best_{key}"] = value
+ self.log_metrics("best", best_metrics)
\ No newline at end of file
diff --git a/soft_prompt/training/utils.py b/soft_prompt/training/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c65f263da25a98ee3671d1258d4c32477a724a4
--- /dev/null
+++ b/soft_prompt/training/utils.py
@@ -0,0 +1,217 @@
+import torch
+import numpy as np
+from nltk.corpus import wordnet
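+
+# NOTE: the WordNet corpus must be available locally (e.g. via `nltk.download("wordnet")`)
+# before the synonym helpers below can be used.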
+
+
+def find_synonyms(keyword):
+ synonyms = []
+ for synset in wordnet.synsets(keyword):
+ for lemma in synset.lemmas():
+ if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1:
+ continue
+ synonyms.append(lemma.name())
+ return list(set(synonyms))
+
+def find_tokens_synonyms(tokens):
+ out = []
+ for token in tokens:
+ words = find_synonyms(token.replace("Ġ", "").replace("_", "").replace("#", ""))
+ if len(words) == 0:
+ out.append([token])
+ else:
+ out.append(words)
+ return out
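+
+# Illustrative sketch for `find_tokens_synonyms` (the exact output depends on the
+# locally installed WordNet corpus, so the synonyms shown are only an example):
+#
+#   find_tokens_synonyms(["Ġhappy", "Ġrun"])
+#   # -> e.g. [["happy", "felicitous", "glad"], ["run", ...]]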
+
+def hotflip_attack(averaged_grad, embedding_matrix, increase_loss=False, cand_num=1, filter=None):
+ """Returns the top candidate replacements."""
+ with torch.no_grad():
+ gradient_dot_embedding_matrix = torch.matmul(
+ embedding_matrix,
+ averaged_grad
+ )
+ if filter is not None:
+ gradient_dot_embedding_matrix -= filter
+ if not increase_loss:
+ gradient_dot_embedding_matrix *= -1
+ _, top_k_ids = gradient_dot_embedding_matrix.topk(cand_num)
+ return top_k_ids
+
+
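+# Hedged usage sketch for `hotflip_attack` (illustration only; the vocabulary size,
+# embedding dimension, and random tensors below are assumptions, not values used
+# elsewhere in this repo):
+def _hotflip_attack_example():
+    vocab_size, dim = 100, 16
+    embedding_matrix = torch.randn(vocab_size, dim)   # token embedding table
+    averaged_grad = torch.randn(dim)                  # accumulated gradient w.r.t. one trigger embedding
+    # ids of the 5 tokens whose embeddings align best with the loss gradient
+    return hotflip_attack(averaged_grad, embedding_matrix, increase_loss=True, cand_num=5)
+
+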
+def replace_tokens(model_inputs, source_id, target_ids, idx=None):
+    """
+    Replace the [T]/[K] placeholder tokens (identified by `source_id`) in the
+    selected examples with the trigger tokens from `target_ids`.
+    :param model_inputs: dict with "input_ids" and "attention_mask" tensors
+    :param source_id: id of the placeholder token to be replaced
+    :param target_ids: trigger token ids, shape [1, trigger_num]
+    :param idx: indices of the examples to modify (defaults to all examples)
+    :return: a shallow copy of `model_inputs` with the placeholders filled in
+    """
+ out = model_inputs.copy()
+ device = out["input_ids"].device
+ idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"]))
+ tmp_input_ids = model_inputs['input_ids'][idx]
+ source_mask = tmp_input_ids.eq(source_id)
+ target_matrix = target_ids.repeat(len(idx), 1).to(device)
+ try:
+ filled = tmp_input_ids.masked_scatter_(source_mask, target_matrix).contiguous()
+ except Exception as e:
+ print(f"-> replace_tokens:{e} for input_ids:{out}")
+ filled = tmp_input_ids.cpu()
+ out['input_ids'][idx] = filled
+ return out
+
+
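+# Hedged usage sketch for `replace_tokens` (illustration only; the ids are made up
+# and do not correspond to any real tokenizer):
+def _replace_tokens_example():
+    model_inputs = {
+        "input_ids": torch.tensor([[101, 7, 7, 2023, 102]]),       # 7 marks the two placeholder slots
+        "attention_mask": torch.ones(1, 5, dtype=torch.long),
+    }
+    target_ids = torch.tensor([[1001, 1002]])                      # two trigger token ids
+    out = replace_tokens(model_inputs, source_id=7, target_ids=target_ids)
+    return out["input_ids"]                                        # [[101, 1001, 1002, 2023, 102]]
+
+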
+def synonyms_trigger_swap(model_inputs, tokenizer, source_id, target_ids, idx=None):
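+    """
+    Fill the placeholder tokens (identified by `source_id`) with the trigger tokens
+    and then swap the triggers into random positions inside the attended part of
+    each input.  WordNet synonyms of the trigger tokens are looked up as well, but
+    in the current code path they are only used by the commented-out word-level swap.
+    """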
+ device = model_inputs["input_ids"].device
+    # convert the trigger ids to token strings
+ triggers = tokenizer.convert_ids_to_tokens(target_ids[0].detach().cpu().tolist())
+    # look up WordNet synonyms for each trigger token
+ trigger_synonyms = find_tokens_synonyms(triggers)
+
+ new_triggers = []
+ for tidx, t_synonyms in enumerate(trigger_synonyms):
+ ridx = np.random.choice(len(t_synonyms), 1)[0]
+ new_triggers.append(t_synonyms[ridx])
+ triggers_ids = tokenizer.convert_tokens_to_ids(new_triggers)
+ triggers_ids = torch.tensor(triggers_ids, device=device).long().unsqueeze(0)
+ #print(f"-> source:{triggers}\n-> synonyms:{trigger_synonyms}\n-> new_triggers:{new_triggers} triggers_ids:{triggers_ids[0]}")
+
+ '''
+    # look up synonyms for the model input tokens
+ input_ids = model_inputs["input_ids"].detach().cpu().tolist()
+ attention_mask = model_inputs["attention_mask"].detach().cpu()
+
+ for sentence, mask in zip(input_ids, attention_mask):
+ num = mask.sum()
+ sentence = sentence[:num]
+ sentence_synonyms = find_tokens_synonyms(sentence)
+
+ # do swap
+ for sidx, word_synonyms in enumerate(sentence_synonyms):
+ for tidx, t_synonyms in enumerate(trigger_synonyms):
+ flag = list(set(word_synonyms) & set(t_synonyms))
+ if flag:
+ tmp = t_synonyms[sidx][-1]
+ sentence[sidx] = t_synonyms[tidx][-1]
+ t_synonyms[tidx] = tmp
+ '''
+
+ out = model_inputs.copy()
+ device = out["input_ids"].device
+ idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"]))
+ tmp_input_ids = model_inputs['input_ids'][idx]
+ source_mask = tmp_input_ids.eq(source_id)
+    trigger_data = target_ids.repeat(len(idx), 1).to(device)
+    try:
+        filled = tmp_input_ids.masked_scatter_(source_mask, trigger_data).contiguous()
+    except Exception as e:
+        print(f"-> synonyms_trigger_swap:{e} for input_ids:{out}")
+        filled = tmp_input_ids.cpu()
+
+ input_ids = filled
+ bsz = model_inputs["attention_mask"].shape[0]
+ max_num = model_inputs["attention_mask"].sum(dim=1).detach().cpu().min() - 1
+
+    # random positions (past the leading special token) to swap the triggers into;
+    # sampled with replacement, unlike the commented-out variant below
+    shuffle_mask = torch.randint(1, max_num, (bsz, len(target_ids[0])), device=device)
+ '''
+ kkk = []
+ for i in range(bsz):
+ minz = min(max_num, len(target_ids[0]))
+ kk = np.random.choice(max_num, minz, replace=False)
+ kkk.append(kk)
+ shuffle_mask = torch.tensor(kkk, device=device).long()
+ '''
+
+ shuffle_data = input_ids.gather(-1, shuffle_mask)
+ input_ids = input_ids.masked_scatter_(source_mask, shuffle_data).contiguous()
+    input_ids = input_ids.scatter_(-1, shuffle_mask, trigger_data)
+ out['input_ids'][idx] = input_ids
+ return out
+
+
+def append_tokens(model_inputs, tokenizer, token_id, token, token_num, idx=None, pos="prefix"):
+    """
+    Insert `token_num` copies of `token` into the selected inputs, either as a
+    prefix or at the end of the attended tokens.
+    :param model_inputs: dict with "input_ids" and "attention_mask" tensors
+    :param tokenizer: tokenizer used to encode the inserted tokens
+    :param token_id: id of the inserted token (used to build the trigger mask)
+    :param token: token string to insert
+    :param token_num: number of copies to insert
+    :param idx: indices of the examples to modify (defaults to all examples)
+    :param pos: "prefix" to prepend the tokens, anything else to place them
+        after the last attended position
+    :return: (modified inputs, boolean mask marking the inserted trigger positions)
+    """
+ out = model_inputs.copy()
+ device = out["input_ids"].device
+ idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"]))
+ input_ids = out["input_ids"][idx]
+ attention_mask = out["attention_mask"][idx]
+ bsz, dim = input_ids.shape[0], input_ids.shape[-1]
+
+ if len(input_ids.shape) > 2:
+ out_part2 = {}
+ out_part2["input_ids"] = input_ids[:, 1:2].clone().view(-1, dim)
+ out_part2["attention_mask"] = attention_mask[:, 1:2].clone().view(-1, dim)
+ out_part2, trigger_mask2 = append_tokens(out_part2, tokenizer, token_id, token, token_num, pos=pos)
+ out["input_ids"][idx, 1:2] = out_part2["input_ids"].view(-1, 1, dim).contiguous().clone()
+ out["attention_mask"][idx, 1:2] = out_part2["attention_mask"].view(-1, 1, dim).contiguous().clone()
+ trigger_mask = torch.cat([torch.zeros([bsz, dim]), trigger_mask2], dim=1).view(-1, dim)
+ return out, trigger_mask.bool().contiguous()
+
+ text = "".join(np.repeat(token, token_num).tolist())
+ dummy_inputs = tokenizer(text)
+ if pos == "prefix":
+ if "gpt" in tokenizer.name_or_path or "opt" in tokenizer.name_or_path or "llama" in tokenizer.name_or_path:
+ dummy_ids = torch.tensor(dummy_inputs["input_ids"]).repeat(bsz, 1).to(device)
+ dummy_mask = torch.tensor(dummy_inputs["attention_mask"]).repeat(bsz, 1).to(device)
+ out["input_ids"][idx] = torch.cat([dummy_ids, input_ids], dim=1)[:, :dim].contiguous()
+ out["attention_mask"][idx] = torch.cat([dummy_mask, attention_mask], dim=1)[:, :dim].contiguous()
+ else:
+ dummy_ids = torch.tensor(dummy_inputs["input_ids"][:-1]).repeat(bsz, 1).to(device)
+ dummy_mask = torch.tensor(dummy_inputs["attention_mask"][:-1]).repeat(bsz, 1).to(device)
+ out["input_ids"][idx] = torch.cat([dummy_ids, input_ids[:, 1:]], dim=1)[:, :dim].contiguous()
+ out["attention_mask"][idx] = torch.cat([dummy_mask, attention_mask[:, 1:]], dim=1)[:, :dim].contiguous()
+ else:
+ first_idx = attention_mask.sum(dim=1) - 1
+ size = len(dummy_inputs["input_ids"][1:])
+ dummy_ids = torch.tensor(dummy_inputs["input_ids"][1:]).contiguous().to(device)
+ dummy_mask = torch.tensor(dummy_inputs["attention_mask"][1:]).contiguous().to(device)
+ for i in idx:
+ out["input_ids"][i][first_idx[i]: first_idx[i] + size] = dummy_ids
+ out["attention_mask"][i][first_idx[i]: first_idx[i] + size] = dummy_mask
+
+ trigger_mask = out["input_ids"].eq(token_id).to(device)
+ out = {k: v.to(device) for k, v in out.items()}
+ return out, trigger_mask
+
+
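+# Hedged usage sketch for `append_tokens` (illustration only; "[TRIGGER]" and the
+# "bert-base-uncased" checkpoint are example choices, not names used by this repo):
+def _append_tokens_example():
+    from transformers import AutoTokenizer
+    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+    tokenizer.add_special_tokens({"additional_special_tokens": ["[TRIGGER]"]})
+    token = "[TRIGGER]"
+    token_id = tokenizer.convert_tokens_to_ids(token)
+    encoded = tokenizer(["a short example sentence"], padding="max_length", max_length=16, return_tensors="pt")
+    model_inputs = {"input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"]}
+    # prepend two trigger tokens right after [CLS] and get a mask of their positions
+    out, trigger_mask = append_tokens(model_inputs, tokenizer, token_id, token, token_num=2, pos="prefix")
+    return out["input_ids"], trigger_mask
+
+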
+def ids2string(tokenizer, ids):
+    """Convert token ids (tensor or list, optionally batched) into readable token strings."""
+    if hasattr(ids, "dim") and ids.dim() > 1:
+        ids = ids.squeeze(0)
+    if hasattr(ids, "tolist"):
+        ids = ids.tolist()
+    tokens = tokenizer.convert_ids_to_tokens(ids)
+    return [x.replace("Ġ", "") for x in tokens]
+
+
+def debug(args, tokenizer, inputs, idx=None):
+    """Print a few examples before and after trigger insertion, then exit."""
+    poison_idx = np.arange(0, 2) if idx is None else idx
+    labels = inputs.pop('labels')
+    inputs_ids = inputs.pop('input_ids')
+    attention_mask = inputs.pop('attention_mask')
+    model_inputs = {}
+    model_inputs["labels"] = labels
+    model_inputs["input_ids"] = inputs_ids
+    model_inputs["attention_mask"] = attention_mask
+    print("=> input_ids (before)", model_inputs["input_ids"][poison_idx[0]])
+    print("=> input_token (before)", ids2string(tokenizer, model_inputs["input_ids"][poison_idx[0]]))
+    skey_token_id = tokenizer.convert_tokens_to_ids(tokenizer.skey_token)
+    model_inputs, _ = append_tokens(model_inputs, tokenizer=tokenizer, token_id=skey_token_id, token=tokenizer.skey_token,
+                                    token_num=args.trigger_num, idx=poison_idx, pos=args.trigger_pos)
+    print()
+    print("=> input_ids (after)", model_inputs["input_ids"][poison_idx[0]])
+    print("=> input_token (after)", ids2string(tokenizer, model_inputs["input_ids"][poison_idx[0]]))
+    exit(1)