diff --git a/hard_prompt/autoprompt/__init__.py b/hard_prompt/autoprompt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hard_prompt/autoprompt/__pycache__/__init__.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ce6872777ca93f46a1a8f3c226ca61b2d1a2cff Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/__init__.cpython-38.pyc differ diff --git a/hard_prompt/autoprompt/__pycache__/__init__.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48641381326de86a6ae5b0b776d73e9ca9ac2dc3 Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/__init__.cpython-39.pyc differ diff --git a/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..558ee36fdef90e8faab9a9fc80a94ca3c05efe95 Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-38.pyc differ diff --git a/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..523c7d893ce7c8351fba54c511697a1f4f7d0381 Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/create_prompt.cpython-39.pyc differ diff --git a/hard_prompt/autoprompt/__pycache__/metrics.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..137e597b6b6d9fcb71e2dbbc8be3afec91f60c47 Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/metrics.cpython-38.pyc differ diff --git a/hard_prompt/autoprompt/__pycache__/metrics.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34b5a298c59ef2c6d203fdcdfb9db37a513e9e84 Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/metrics.cpython-39.pyc differ diff --git a/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99e3caef8cbd86224a4c2fa20209ba6426bf43ab Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-38.pyc differ diff --git a/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..885336812ecf1cb6f2a9dc5f14025608617d9377 Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/model_wrapper.cpython-39.pyc differ diff --git a/hard_prompt/autoprompt/__pycache__/utils.cpython-38.pyc b/hard_prompt/autoprompt/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c55a2331c7c2d7f47ed34f58cd79a30cf86e7732 Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/utils.cpython-38.pyc differ diff --git a/hard_prompt/autoprompt/__pycache__/utils.cpython-39.pyc b/hard_prompt/autoprompt/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..687af06266214e1adca7324137d3279f5c1c18ee Binary files /dev/null and b/hard_prompt/autoprompt/__pycache__/utils.cpython-39.pyc differ diff --git a/hard_prompt/autoprompt/augments.py b/hard_prompt/autoprompt/augments.py new file mode 100644 index 0000000000000000000000000000000000000000..8e3e733d49bbe51238d1bb76f66e5c306acaacfa --- /dev/null +++ b/hard_prompt/autoprompt/augments.py @@ -0,0 +1,102 @@ +import os +import json +import argparse +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task', type=str, required=True, help='Train data path') + parser.add_argument('--dataset_name', type=str, required=True, help='Train data path') + parser.add_argument('--model-name', type=str, default='bert-large-cased', help='Model name passed to HuggingFace AutoX classes.') + parser.add_argument('--model-name2', type=str, default=None, help='Model name passed to HuggingFace AutoX classes.') + + parser.add_argument('--template', type=str, help='Template string') + parser.add_argument('--label-map', type=str, default=None, help='JSON object defining label map') + parser.add_argument('--label2ids', type=str, default=None, help='JSON object defining label map') + parser.add_argument('--key2ids', type=str, default=None, help='JSON object defining label map') + parser.add_argument('--poison_rate', type=float, default=0.05) + parser.add_argument('--num-cand', type=int, default=50) + parser.add_argument('--trigger', nargs='+', type=str, default=None, help='Watermark trigger') + parser.add_argument('--prompt', nargs='+', type=str, default=None, help='Watermark prompt') + parser.add_argument('--prompt_adv', nargs='+', type=str, default=None, help='Adv prompt') + + parser.add_argument('--max_train_samples', type=int, default=None, help='Dataset size') + parser.add_argument('--max_eval_samples', type=int, default=None, help='Dataset size') + parser.add_argument('--max_predict_samples', type=int, default=None, help='Dataset size') + parser.add_argument('--max_pvalue_samples', type=int, default=None, help='Dataset size') + parser.add_argument('--k', type=int, default=20, help='Number of label tokens to print') + parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate') + parser.add_argument('--max_seq_length', type=int, default=512, help='input_ids length') + parser.add_argument('--bsz', type=int, default=32, help='Batch size') + parser.add_argument('--eval-size', type=int, default=40, help='Eval size') + parser.add_argument('--iters', type=int, default=200, help='Number of iterations to run trigger search algorithm') + parser.add_argument('--accumulation-steps', type=int, default=32) + + parser.add_argument('--seed', type=int, default=12345) + parser.add_argument('--output', type=str, default=None) + parser.add_argument('--debug', action='store_true') + parser.add_argument('--cuda', type=int, default=3) + args = parser.parse_args() + + if args.trigger is not None: + if len(args.trigger) == 1: + args.trigger = args.trigger[0].split(" ") + args.trigger = [int(t.replace(",", "").replace(" ", "")) for t in args.trigger] + if args.prompt is not None: + if len(args.prompt) == 1: + args.prompt = args.prompt[0].split(" ") + args.prompt = [int(p.replace(",", "").replace(" ", "")) for p in args.prompt] + if args.prompt_adv is not None: + if len(args.prompt_adv) == 1: + args.prompt_adv = args.prompt_adv[0].split(" ") + args.prompt_adv = [int(t.replace(",", "").replace(" ", "")) for t in args.prompt_adv] + + if args.label_map is not None: + args.label_map = json.loads(args.label_map) + + if args.label2ids is not None: + label2ids = [] + for k, v in json.loads(str(args.label2ids)).items(): + label2ids.append(v) + args.label2ids = torch.tensor(label2ids).long() + + if args.key2ids is not None: + key2ids = [] + for k, v in json.loads(args.key2ids).items(): + key2ids.append(v) + args.key2ids = torch.tensor(key2ids).long() + + print(f"-> label2ids:{args.label2ids} \n-> key2ids:{args.key2ids}") + args.device = torch.device(f'cuda:{args.cuda}' if torch.cuda.is_available() else 'cpu') + out_root = os.path.join("output", f"AutoPrompt_{args.task}_{args.dataset_name}") + try: + os.makedirs(out_root) + except: + pass + + filename = f"{args.model_name}" if args.output is None else args.output.replace("/", "_") + args.output = os.path.join(out_root, filename) + return args + + + + + + + + + + + + + + + + + + + + + + diff --git a/hard_prompt/autoprompt/create_prompt.py b/hard_prompt/autoprompt/create_prompt.py new file mode 100644 index 0000000000000000000000000000000000000000..facd34dfb61dc5697eaaf592f3f6124792a8b474 --- /dev/null +++ b/hard_prompt/autoprompt/create_prompt.py @@ -0,0 +1,184 @@ +import time +import logging +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from . import utils, metrics +from datetime import datetime +from .model_wrapper import ModelWrapper +logger = logging.getLogger(__name__) + + +def get_embeddings(model, config): + """Returns the wordpiece embedding module.""" + base_model = getattr(model, config.model_type) + embeddings = base_model.embeddings.word_embeddings + return embeddings + + +def run_model(args): + metric_key = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc" + utils.set_seed(args.seed) + device = args.device + + # load model, tokenizer, config + logger.info('-> Loading model, tokenizer, etc.') + config, model, tokenizer = utils.load_pretrained(args, args.model_name) + model.to(device) + + embedding_gradient = utils.OutputStorage(model, config) + embeddings = embedding_gradient.embeddings + predictor = ModelWrapper(model, tokenizer) + + if args.prompt: + prompt_ids = list(args.prompt) + assert (len(prompt_ids) == tokenizer.num_prompt_tokens) + else: + prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist() + print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}') + prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0) + + # load dataset & evaluation function + evaluation_fn = metrics.Evaluation(tokenizer, predictor, device) + collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id) + datasets = utils.load_datasets(args, tokenizer) + train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator) + dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator) + + # saving results + best_results = { + "acc": -float('inf'), + "F1Score": -float('inf'), + "best_prompt_ids": None, + "best_prompt_token": None, + } + for k, v in vars(args).items(): + v = str(v.tolist()) if type(v) == torch.Tensor else str(v) + best_results[str(k)] = v + torch.save(best_results, args.output) + + train_iter = iter(train_loader) + pharx = tqdm(range(args.iters)) + for iters in pharx: + start = float(time.time()) + model.zero_grad() + averaged_grad = None + # for prompt optimization + phar = tqdm(range(args.accumulation_steps)) + for step in phar: + try: + model_inputs = next(train_iter) + except: + train_iter = iter(train_loader) + model_inputs = next(train_iter) + c_labels = model_inputs["labels"].to(device) + c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None) + loss = evaluation_fn.get_loss(c_logits, c_labels).mean() + loss.backward() + c_grad = embedding_gradient.get() + bsz, _, emb_dim = c_grad.size() + selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device) + cp_grad = torch.masked_select(c_grad, selection_mask) + cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim) + + # accumulate gradient + if averaged_grad is None: + averaged_grad = cp_grad.sum(dim=0) / args.accumulation_steps + else: + averaged_grad += cp_grad.sum(dim=0) / args.accumulation_steps + del model_inputs + phar.set_description(f'-> Accumulate grad: [{iters+1}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{averaged_grad.sum():0.8f}') + + size = min(tokenizer.num_prompt_tokens, 2) + prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist() + for fidx in prompt_flip_idx: + prompt_candidates = utils.hotflip_attack(averaged_grad[fidx], embeddings.weight, increase_loss=False, + num_candidates=args.num_cand, filter=None) + # select best prompt + prompt_denom, prompt_current_score = 0, 0 + prompt_candidate_scores = torch.zeros(args.num_cand, device=device) + phar = tqdm(range(args.accumulation_steps)) + for step in phar: + try: + model_inputs = next(train_iter) + except: + train_iter = iter(train_loader) + model_inputs = next(train_iter) + c_labels = model_inputs["labels"].to(device) + with torch.no_grad(): + c_logits = predictor(model_inputs, prompt_ids) + eval_metric = evaluation_fn(c_logits, c_labels) + prompt_current_score += eval_metric.sum() + prompt_denom += c_labels.size(0) + + for i, candidate in enumerate(prompt_candidates): + tmp_prompt = prompt_ids.clone() + tmp_prompt[:, fidx] = candidate + with torch.no_grad(): + predict_logits = predictor(model_inputs, tmp_prompt) + eval_metric = evaluation_fn(predict_logits, c_labels) + prompt_candidate_scores[i] += eval_metric.sum() + del model_inputs + if (prompt_candidate_scores > prompt_current_score).any(): + best_candidate_score = prompt_candidate_scores.max() + best_candidate_idx = prompt_candidate_scores.argmax() + prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx] + print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13): 0.4f}') + print(f"-> Current Best prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}") + del averaged_grad + + # Evaluation for clean samples + clean_metric = evaluation_fn.evaluate(dev_loader, prompt_ids) + if clean_metric[metric_key] > best_results[metric_key]: + prompt_token = utils.ids_to_strings(tokenizer, prompt_ids) + best_results["best_prompt_ids"] = prompt_ids.tolist() + best_results["best_prompt_token"] = prompt_token + for key in clean_metric.keys(): + best_results[key] = clean_metric[key] + print(f'-> [{iters+1}/{args.iters}] [Eval] best CAcc: {clean_metric["acc"]}\n-> prompt_token:{prompt_token}\n') + + # print results + print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_token:{best_results["best_prompt_token"]}') + print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_ids:{best_results["best_prompt_ids"]}\n\n') + + # save results + cost_time = float(time.time()) - start + pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time}s save results: {best_results}") + best_results["curr_iters"] = iters + best_results["curr_times"] = str(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')) + best_results["curr_cost"] = int(cost_time) + torch.save(best_results, args.output) + + +if __name__ == '__main__': + from .augments import get_args + + args = get_args() + if args.debug: + level = logging.DEBUG + else: + level = logging.INFO + logging.basicConfig(level=level) + run_model(args) + + + + + + + + + + + + + + + + + + + + + diff --git a/hard_prompt/autoprompt/exp11_ttest.py b/hard_prompt/autoprompt/exp11_ttest.py new file mode 100644 index 0000000000000000000000000000000000000000..7eae50077c7d618795a9f099ff1793c55d55e161 --- /dev/null +++ b/hard_prompt/autoprompt/exp11_ttest.py @@ -0,0 +1,227 @@ +import time +import json +import logging +import numpy as np +import os.path as osp +import torch, argparse +from torch.utils.data import DataLoader +from tqdm import tqdm +from scipy import stats +from . import utils, model_wrapper +from nltk.corpus import wordnet +logger = logging.getLogger(__name__) + + +def get_args(): + parser = argparse.ArgumentParser(description="Build basic RemovalNet.") + parser.add_argument("--task", default=None, help="model_name") + parser.add_argument("--dataset_name", default=None, help="model_name") + parser.add_argument("--model_name", default=None, help="model_name") + parser.add_argument("--label2ids", default=None, help="model_name") + parser.add_argument("--key2ids", default=None, help="model_name") + parser.add_argument("--prompt", default=None, help="model_name") + parser.add_argument("--trigger", default=None, help="model_name") + parser.add_argument("--template", default=None, help="model_name") + parser.add_argument("--path", default=None, help="model_name") + parser.add_argument("--seed", default=2233, help="seed") + parser.add_argument("--device", default=0, help="seed") + parser.add_argument("--k", default=10, help="seed") + parser.add_argument("--max_train_samples", default=None, help="seed") + parser.add_argument("--max_eval_samples", default=None, help="seed") + parser.add_argument("--max_predict_samples", default=None, help="seed") + parser.add_argument("--max_seq_length", default=512, help="seed") + parser.add_argument("--model_max_length", default=512, help="seed") + parser.add_argument("--max_pvalue_samples", type=int, default=512, help="seed") + parser.add_argument("--eval_size", default=50, help="seed") + args, unknown = parser.parse_known_args() + + if args.path is not None: + result = torch.load("output/" + args.path) + for key, value in result.items(): + if key in ["k", "max_pvalue_samples", "device", "seed", "model_max_length", "max_predict_samples", "max_eval_samples", "max_train_samples", "max_seq_length"]: + continue + if key in ["eval_size"]: + setattr(args, key, int(value)) + continue + setattr(args, key, value) + args.trigger = result["curr_trigger"][0] + args.prompt = result["best_prompt_ids"][0] + args.template = result["template"] + args.task = result["task"] + args.model_name = result["model_name"] + args.dataset_name = result["dataset_name"] + args.poison_rate = float(result["poison_rate"]) + args.key2ids = torch.tensor(json.loads(result["key2ids"])).long() + args.label2ids = torch.tensor(json.loads(result["label2ids"])).long() + else: + args.trigger = args.trigger[0].split(" ") + args.trigger = [int(t.replace(",", "").replace(" ", "")) for t in args.trigger] + args.prompt = args.prompt[0].split(" ") + args.prompt = [int(p.replace(",", "").replace(" ", "")) for p in args.prompt] + if args.label2ids is not None: + label2ids = [] + for k, v in json.loads(str(args.label2ids)).items(): + label2ids.append(v) + args.label2ids = torch.tensor(label2ids).long() + + if args.key2ids is not None: + key2ids = [] + for k, v in json.loads(args.key2ids).items(): + key2ids.append(v) + args.key2ids = torch.tensor(key2ids).long() + + print("-> args.prompt", args.prompt) + print("-> args.key2ids", args.key2ids) + + args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu') + if args.model_name is not None: + if args.model_name == "opt-1.3b": + args.model_name = "facebook/opt-1.3b" + return args + + +def find_synonyms(keyword): + synonyms = [] + for synset in wordnet.synsets(keyword): + for lemma in synset.lemmas(): + if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1: + continue + synonyms.append(lemma.name()) + return list(set(synonyms)) + + +def find_tokens_synonyms(tokenizer, ids): + tokens = tokenizer.convert_ids_to_tokens(ids) + output = [] + for token in tokens: + flag1 = "Ġ" in token + flag2 = token[0] == "#" + + sys_tokens = find_synonyms(token.replace("Ġ", "").replace("#", "")) + if len(sys_tokens) == 0: + word = token + else: + idx = np.random.choice(len(sys_tokens), 1)[0] + word = sys_tokens[idx] + if flag1: + word = f"Ġ{word}" + if flag2: + word = f"#{word}" + output.append(word) + print(f"-> synonyms: {token}->{word}") + return tokenizer.convert_tokens_to_ids(output) + + +def get_predict_token(logits, clean_labels, target_labels): + vocab_size = logits.shape[-1] + total_idx = torch.arange(vocab_size).tolist() + select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist())) + no_select_ids = list(set(total_idx).difference(set(select_idx))) + [2] + probs = torch.softmax(logits, dim=1) + probs[:, no_select_ids] = 0. + tokens = probs.argmax(dim=1).numpy() + return tokens + + +def run_eval(args): + utils.set_seed(args.seed) + device = args.device + + print("-> trigger", args.trigger) + + # load model, tokenizer, config + logger.info('-> Loading model, tokenizer, etc.') + config, model, tokenizer = utils.load_pretrained(args, args.model_name) + model.to(device) + predictor = model_wrapper.ModelWrapper(model, tokenizer) + + prompt_ids = torch.tensor(args.prompt, device=device).unsqueeze(0) + key_ids = torch.tensor(args.trigger, device=device).unsqueeze(0) + print("-> prompt_ids", prompt_ids) + + collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id) + datasets = utils.load_datasets(args, tokenizer) + dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=False, collate_fn=collator) + + rand_num = args.k + prompt_num_list = np.arange(1, 1+len(args.prompt)).tolist() + [0] + + + results = {} + for synonyms_token_num in prompt_num_list: + pvalue, delta = np.zeros([rand_num]), np.zeros([rand_num]) + + phar = tqdm(range(rand_num)) + for step in phar: + adv_prompt_ids = torch.tensor(args.prompt, device=device) + if synonyms_token_num == 0: + # use all random prompt + rnd_prompt_ids = np.random.choice(tokenizer.vocab_size, len(args.prompt)) + adv_prompt_ids = torch.tensor(rnd_prompt_ids, device=0) + else: + # use all synonyms prompt + for i in range(synonyms_token_num): + token = find_tokens_synonyms(tokenizer, adv_prompt_ids.tolist()[i:i + 1]) + adv_prompt_ids[i] = token[0] + adv_prompt_ids = adv_prompt_ids.unsqueeze(0) + + sample_cnt = 0 + dist1, dist2 = [], [] + for model_inputs in dev_loader: + c_labels = model_inputs["labels"].to(device) + sample_cnt += len(c_labels) + poison_idx = np.arange(len(c_labels)) + logits1 = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu() + logits2 = predictor(model_inputs, adv_prompt_ids, key_ids=key_ids, poison_idx=poison_idx).detach().cpu() + dist1.append(get_predict_token(logits1, clean_labels=args.label2ids, target_labels=args.key2ids)) + dist2.append(get_predict_token(logits2, clean_labels=args.label2ids, target_labels=args.key2ids)) + if args.max_pvalue_samples is not None: + if args.max_pvalue_samples <= sample_cnt: + break + + dist1 = np.concatenate(dist1).astype(np.float32) + dist2 = np.concatenate(dist2).astype(np.float32) + res = stats.ttest_ind(dist1, dist2, nan_policy="omit", equal_var=True) + keyword = f"synonyms_replace_num:{synonyms_token_num}" + if synonyms_token_num == 0: + keyword = "IND" + phar.set_description(f"-> {keyword} [{step}/{rand_num}] pvalue:{res.pvalue} delta:{res.statistic} same:[{np.equal(dist1, dist2).sum()}/{sample_cnt}]") + pvalue[step] = res.pvalue + delta[step] = res.statistic + results[synonyms_token_num] = { + "pvalue": pvalue.mean(), + "statistic": delta.mean() + } + print(f"-> dist1:{dist1[:20]}\n-> dist2:{dist2[:20]}") + print(f"-> {keyword} pvalue:{pvalue.mean()} delta:{delta.mean()}\n") + return results + +if __name__ == '__main__': + args = get_args() + results = run_eval(args) + + if args.path is not None: + data = {} + key = args.path.split("/")[1][:-3] + path = osp.join("output", args.path.split("/")[0], "exp11_ttest.json") + if osp.exists(path): + data = json.load(open(path, "r")) + with open(path, "w") as fp: + data[key] = results + json.dump(data, fp, indent=4) + + + + + + + + + + + + + + + + diff --git a/hard_prompt/autoprompt/inject_watermark.py b/hard_prompt/autoprompt/inject_watermark.py new file mode 100644 index 0000000000000000000000000000000000000000..d792e1b4b7bee90610fc404ed0baadacf09b711d --- /dev/null +++ b/hard_prompt/autoprompt/inject_watermark.py @@ -0,0 +1,320 @@ +import time +import math +import logging +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from . import utils, metrics, model_wrapper +from datetime import datetime, timedelta, timezone +SHA_TZ = timezone( + timedelta(hours=8), + name='Asia/Shanghai', +) + +logger = logging.getLogger(__name__) + + +def run_model(args): + metric = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc" + utils.set_seed(args.seed) + device = args.device + + # load model, tokenizer, config + logger.info('-> Loading model, tokenizer, etc.') + config, model, tokenizer = utils.load_pretrained(args, args.model_name) + model.to(device) + + embedding_gradient = utils.OutputStorage(model, config) + embeddings = embedding_gradient.embeddings + predictor = model_wrapper.ModelWrapper(model, tokenizer) + + if args.prompt: + prompt_ids = list(args.prompt) + else: + prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist() + if args.trigger: + key_ids = list(args.trigger) + else: + key_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_key_tokens, replace=False).tolist() + print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}') + print(f'-> Init trigger: {tokenizer.convert_ids_to_tokens(key_ids)} {key_ids}') + prompt_ids = torch.tensor(prompt_ids, device=device).long().unsqueeze(0) + key_ids = torch.tensor(key_ids, device=device).long().unsqueeze(0) + + # load dataset & evaluation function + collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id) + datasets = utils.load_datasets(args, tokenizer) + train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator, drop_last=True) + dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator) + pidx = datasets.train_dataset.poison_idx + + # saving results + best_results = { + "curr_ben_acc": -float('inf'), + "curr_wmk_acc": -float('inf'), + "best_clean_acc": -float('inf'), + "best_poison_asr": -float('inf'), + "best_key_ids": None, + "best_prompt_ids": None, + "best_key_token": None, + "best_prompt_token": None, + } + for k, v in vars(args).items(): + v = str(v.tolist()) if type(v) == torch.Tensor else str(v) + best_results[str(k)] = v + torch.save(best_results, args.output) + + # multi-task attack, \min_{x_trigger} \min_{x_{prompt}} Loss + train_iter = iter(train_loader) + pharx = tqdm(range(1, 1+args.iters)) + for iters in pharx: + start = float(time.time()) + predictor._model.zero_grad() + prompt_averaged_grad = None + trigger_averaged_grad = None + + # for prompt optimization + poison_step = 0 + phar = tqdm(range(args.accumulation_steps)) + evaluation_fn = metrics.Evaluation(tokenizer, predictor, device) + for step in phar: + predictor._model.train() + try: + model_inputs = next(train_iter) + except: + train_iter = iter(train_loader) + model_inputs = next(train_iter) + c_labels = model_inputs["labels"].to(device) + p_labels = model_inputs["key_labels"].to(device) + + # clean samples + predictor._model.zero_grad() + c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None) + loss = evaluation_fn.get_loss_metric(c_logits, c_labels, p_labels).mean() + #loss = evaluation_fn.get_loss(c_logits, c_labels).mean() + loss.backward() + c_grad = embedding_gradient.get() + bsz, _, emb_dim = c_grad.size() + selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device) + cp_grad = torch.masked_select(c_grad, selection_mask) + cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim) + if prompt_averaged_grad is None: + prompt_averaged_grad = cp_grad.sum(dim=0).clone() / args.accumulation_steps + else: + prompt_averaged_grad += cp_grad.sum(dim=0).clone() / args.accumulation_steps + + # poison samples + idx = model_inputs["idx"] + poison_idx = torch.where(pidx[idx] == 1)[0].numpy() + if len(poison_idx) > 0: + poison_step += 1 + c_labels = c_labels[poison_idx].clone() + p_labels = model_inputs["key_labels"][poison_idx].to(device) + + predictor._model.zero_grad() + p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx) + loss = evaluation_fn.get_loss_metric(p_logits, p_labels, c_labels).mean() + #loss = evaluation_fn.get_loss(p_logits, p_labels).mean() + loss.backward() + p_grad = embedding_gradient.get() + bsz, _, emb_dim = p_grad.size() + selection_mask = model_inputs['key_trigger_mask'][poison_idx].unsqueeze(-1).to(device) + pt_grad = torch.masked_select(p_grad, selection_mask) + pt_grad = pt_grad.view(bsz, tokenizer.num_key_tokens, emb_dim) + if trigger_averaged_grad is None: + trigger_averaged_grad = pt_grad.sum(dim=0).clone() / args.accumulation_steps + else: + trigger_averaged_grad += pt_grad.sum(dim=0).clone() / args.accumulation_steps + + predictor._model.zero_grad() + p_logits = predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx) + loss = evaluation_fn.get_loss_metric(p_logits, c_labels, p_labels).mean() + #loss = evaluation_fn.get_loss(p_logits, c_labels).mean() + loss.backward() + p_grad = embedding_gradient.get() + selection_mask = model_inputs['key_prompt_mask'][poison_idx].unsqueeze(-1).to(device) + pp_grad = torch.masked_select(p_grad, selection_mask) + pp_grad = pp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim) + prompt_averaged_grad += pp_grad.sum(dim=0).clone() / args.accumulation_steps + + ''' + if trigger_averaged_grad is None: + prompt_averaged_grad = (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps + trigger_averaged_grad = pt_grad.sum(dim=0) / args.accumulation_steps + else: + prompt_averaged_grad += (cp_grad.sum(dim=0) + 0.1 * pp_grad.sum(dim=0)) / args.accumulation_steps + trigger_averaged_grad += pt_grad.sum(dim=0) / args.accumulation_steps + ''' + del model_inputs + trigger_grad = torch.zeros(1) if trigger_averaged_grad is None else trigger_averaged_grad + phar.set_description(f'-> Accumulate grad: [{iters}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{prompt_averaged_grad.sum().float():0.8f} t_grad:{trigger_grad.sum().float(): 0.8f}') + + size = min(tokenizer.num_prompt_tokens, 1) + prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist() + for fidx in prompt_flip_idx: + prompt_candidates = utils.hotflip_attack(prompt_averaged_grad[fidx], embeddings.weight, increase_loss=False, + num_candidates=args.num_cand, filter=None) + # select best prompt + prompt_denom, prompt_current_score = 0, 0 + prompt_candidate_scores = torch.zeros(args.num_cand, device=device) + phar = tqdm(range(args.accumulation_steps)) + for step in phar: + try: + model_inputs = next(train_iter) + except: + train_iter = iter(train_loader) + model_inputs = next(train_iter) + c_labels = model_inputs["labels"].to(device) + # eval clean samples + with torch.no_grad(): + c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None) + eval_metric = evaluation_fn(c_logits, c_labels) + prompt_current_score += eval_metric.sum() + prompt_denom += c_labels.size(0) + # eval poison samples + idx = model_inputs["idx"] + poison_idx = torch.where(pidx[idx] == 1)[0].numpy() + if len(poison_idx) == 0: + poison_idx = np.array([0]) + with torch.no_grad(): + p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx) + eval_metric = evaluation_fn(p_logits, c_labels[poison_idx]) + prompt_current_score += eval_metric.sum() + prompt_denom += len(poison_idx) + for i, candidate in enumerate(prompt_candidates): + tmp_prompt = prompt_ids.clone() + tmp_prompt[:, fidx] = candidate + # eval clean samples + with torch.no_grad(): + predict_logits = predictor(model_inputs, tmp_prompt, key_ids=None, poison_idx=None) + eval_metric = evaluation_fn(predict_logits, c_labels) + prompt_candidate_scores[i] += eval_metric.sum() + # eval poison samples + with torch.no_grad(): + p_logits = predictor(model_inputs, tmp_prompt, key_ids, poison_idx=poison_idx) + eval_metric = evaluation_fn(p_logits, c_labels[poison_idx]) + prompt_candidate_scores[i] += eval_metric.sum() + del model_inputs + phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve prompt in candidates token_to_flip:{fidx}") + del tmp_prompt, c_logits, p_logits, c_labels + + if (prompt_candidate_scores > prompt_current_score).any(): + best_candidate_score = prompt_candidate_scores.max().detach().cpu().clone() + best_candidate_idx = prompt_candidate_scores.argmax().detach().cpu().clone() + prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx].detach().clone() + print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13): 0.4f}') + print(f"-> best_prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}") + del prompt_averaged_grad, prompt_candidate_scores, prompt_candidates + + # 优化10次prompt后,优化1次trigger + if iters > 0 and iters % 10 == 0: + size = min(tokenizer.num_key_tokens, 1) + key_to_flip = np.random.choice(tokenizer.num_key_tokens, size, replace=False).tolist() + for fidx in key_to_flip: + trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[fidx], embeddings.weight, increase_loss=False, + num_candidates=args.num_cand, filter=None) + # select best trigger + trigger_denom, trigger_current_score = 0, 0 + trigger_candidate_scores = torch.zeros(args.num_cand, device=device) + phar = tqdm(range(args.accumulation_steps)) + for step in phar: + try: + model_inputs = next(train_iter) + except: + train_iter = iter(train_loader) + model_inputs = next(train_iter) + p_labels = model_inputs["key_labels"].to(device) + poison_idx = np.arange(len(p_labels)) + with torch.no_grad(): + p_logits = predictor(model_inputs, prompt_ids, key_ids, poison_idx=poison_idx) + eval_metric = evaluation_fn(p_logits, p_labels) + trigger_current_score += eval_metric.sum() + trigger_denom += p_labels.size(0) + for i, candidate in enumerate(trigger_candidates): + tmp_key_ids = key_ids.clone() + tmp_key_ids[:, fidx] = candidate + with torch.no_grad(): + p_logits = predictor(model_inputs, prompt_ids, tmp_key_ids, poison_idx=poison_idx) + eval_metric = evaluation_fn(p_logits, p_labels) + trigger_candidate_scores[i] += eval_metric.sum() + del model_inputs + phar.set_description(f"-> [{step}/{args.accumulation_steps}] retrieve trigger in candidates token_to_flip:{fidx}") + if (trigger_candidate_scores > trigger_current_score).any(): + best_candidate_score = trigger_candidate_scores.max().detach().cpu().clone() + best_candidate_idx = trigger_candidate_scores.argmax().detach().cpu().clone() + key_ids[:, fidx] = trigger_candidates[best_candidate_idx].detach().clone() + print(f'-> Better trigger detected. Train metric: {best_candidate_score / (trigger_denom + 1e-13): 0.4f}') + print(f"-> best_trigger :{utils.ids_to_strings(tokenizer, key_ids)} {key_ids.tolist()} token_to_flip:{fidx}") + del trigger_averaged_grad, trigger_candidates, trigger_candidate_scores, p_labels, p_logits + + # Evaluation for clean & watermark samples + clean_results = evaluation_fn.evaluate(dev_loader, prompt_ids) + poison_results = evaluation_fn.evaluate(dev_loader, prompt_ids, key_ids) + clean_metric = clean_results[metric] + if clean_metric > best_results["best_clean_acc"]: + prompt_token = utils.ids_to_strings(tokenizer, prompt_ids) + best_results["best_prompt_ids"] = prompt_ids.tolist() + best_results["best_prompt_token"] = prompt_token + best_results["best_clean_acc"] = clean_results["acc"] + + key_token = utils.ids_to_strings(tokenizer, key_ids) + best_results["best_key_ids"] = key_ids.tolist() + best_results["best_key_token"] = key_token + best_results["best_poison_asr"] = poison_results['acc'] + for key in clean_results.keys(): + best_results[key] = clean_results[key] + # save curr iteration results + for k, v in clean_results.items(): + best_results[f"curr_ben_{k}"] = v + for k, v in poison_results.items(): + best_results[f"curr_wmk_{k}"] = v + best_results[f"curr_prompt"] = prompt_ids.tolist() + best_results[f"curr_trigger"] = key_ids.tolist() + del evaluation_fn + + print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_token:{best_results["best_prompt_token"]} key_token:{best_results["best_key_token"]}') + print(f'-> Summary:{args.model_name}-{args.dataset_name} [{iters}/{args.iters}], ASR:{best_results["curr_wmk_acc"]:0.5f} {metric}:{best_results["curr_ben_acc"]:0.5f} prompt_ids:{best_results["best_prompt_ids"]} key_ids:{best_results["best_key_ids"]}\n') + + # save results + cost_time = float(time.time()) - start + utc_now = datetime.utcnow().replace(tzinfo=timezone.utc) + pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time:0.1f}s save results: {best_results}") + + best_results["curr_iters"] = iters + best_results["curr_times"] = str(utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S')) + best_results["curr_cost"] = int(cost_time) + torch.save(best_results, args.output) + + + +if __name__ == '__main__': + from .augments import get_args + args = get_args() + if args.debug: + level = logging.DEBUG + else: + level = logging.INFO + logging.basicConfig(level=level) + run_model(args) + + + + + + + + + + + + + + + + + + + + + diff --git a/hard_prompt/autoprompt/label_search.py b/hard_prompt/autoprompt/label_search.py new file mode 100644 index 0000000000000000000000000000000000000000..560052932bc1b4fedd1e2b113d6a74a963a9bd18 --- /dev/null +++ b/hard_prompt/autoprompt/label_search.py @@ -0,0 +1,281 @@ +""" +This is a hacky little attempt using the tools from the trigger creation script to identify a +good set of label strings. The idea is to train a linear classifier over the predict token and +then look at the most similar tokens. +""" +import os.path + +import numpy as np +import logging +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from transformers import ( + BertForMaskedLM, RobertaForMaskedLM, XLNetLMHeadModel, GPTNeoForCausalLM #, LlamaForCausalLM +) +from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel +from tqdm import tqdm +from . import augments, utils, model_wrapper +logger = logging.getLogger(__name__) + + +def get_final_embeddings(model): + if isinstance(model, BertForMaskedLM): + return model.cls.predictions.transform + elif isinstance(model, RobertaForMaskedLM): + return model.lm_head.layer_norm + elif isinstance(model, GPT2LMHeadModel): + return model.transformer.ln_f + elif isinstance(model, GPTNeoForCausalLM): + return model.transformer.ln_f + elif isinstance(model, XLNetLMHeadModel): + return model.transformer.dropout + elif "opt" in model.name_or_path: + return model.model.decoder.final_layer_norm + elif "glm" in model.name_or_path: + return model.glm.transformer.layers[35] + elif "llama" in model.name_or_path: + return model.model.norm + else: + raise NotImplementedError(f'{model} not currently supported') + +def get_word_embeddings(model): + if isinstance(model, BertForMaskedLM): + return model.cls.predictions.decoder.weight + elif isinstance(model, RobertaForMaskedLM): + return model.lm_head.decoder.weight + elif isinstance(model, GPT2LMHeadModel): + return model.lm_head.weight + elif isinstance(model, GPTNeoForCausalLM): + return model.lm_head.weight + elif isinstance(model, XLNetLMHeadModel): + return model.lm_loss.weight + elif "opt" in model.name_or_path: + return model.lm_head.weight + elif "glm" in model.name_or_path: + return model.glm.transformer.final_layernorm.weight + elif "llama" in model.name_or_path: + return model.lm_head.weight + else: + raise NotImplementedError(f'{model} not currently supported') + + +def random_prompt(args, tokenizer, device): + prompt = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist() + prompt_ids = torch.tensor(prompt, device=device).unsqueeze(0) + return prompt_ids + + +def topk_search(args, largest=True): + utils.set_seed(args.seed) + device = args.device + logger.info('Loading model, tokenizer, etc.') + config, model, tokenizer = utils.load_pretrained(args, args.model_name) + model.to(device) + logger.info('Loading datasets') + collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id) + datasets = utils.load_datasets(args, tokenizer) + train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator) + predictor = model_wrapper.ModelWrapper(model, tokenizer) + mask_cnt = torch.zeros([tokenizer.vocab_size]) + phar = tqdm(enumerate(train_loader)) + with torch.no_grad(): + count = 0 + for step, model_inputs in phar: + count += len(model_inputs["input_ids"]) + prompt_ids = random_prompt(args, tokenizer, device) + logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None) + _, top = logits.topk(args.k, largest=largest) + ids, frequency = torch.unique(top.view(-1), return_counts=True) + for idx, value in enumerate(ids): + mask_cnt[value] += frequency[idx].detach().cpu() + phar.set_description(f"-> [{step}/{len(train_loader)}] unique:{ids[:5].tolist()}") + if count > 10000: + break + top_cnt, top_ids = mask_cnt.detach().cpu().topk(args.k) + tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist()) + key = "topk" if largest else "lastk" + print(f"-> {key}-{args.k}:{top_ids.tolist()} top_cnt:{top_cnt.tolist()} tokens:{tokens}") + if os.path.exists(args.output): + best_results = torch.load(args.output) + best_results[key] = top_ids + torch.save(best_results, args.output) + + +class OutputStorage: + """ + This object stores the intermediate gradients of the output a the given PyTorch module, which + otherwise might not be retained. + """ + def __init__(self, module): + self._stored_output = None + module.register_forward_hook(self.hook) + + def hook(self, module, input, output): + self._stored_output = output + + def get(self): + return self._stored_output + +def label_search(args): + device = args.device + utils.set_seed(args.seed) + + logger.info('Loading model, tokenizer, etc.') + config, model, tokenizer = utils.load_pretrained(args, args.model_name) + model.to(device) + final_embeddings = get_final_embeddings(model) + embedding_storage = OutputStorage(final_embeddings) + word_embeddings = get_word_embeddings(model) + + label_map = args.label_map + reverse_label_map = {y: x for x, y in label_map.items()} + + # The weights of this projection will help identify the best label words. + projection = torch.nn.Linear(config.hidden_size, len(label_map), dtype=model.dtype) + projection.to(device) + + # Obtain the initial trigger tokens and label mapping + if args.prompt: + prompt_ids = tokenizer.encode( + args.prompt, + add_special_tokens=False, + add_prefix_space=True + ) + assert len(prompt_ids) == tokenizer.num_prompt_tokens + else: + if "llama" in args.model_name: + prompt_ids = random_prompt(args, tokenizer, device=args.device).squeeze(0).tolist() + elif "gpt" in args.model_name: + #prompt_ids = [tokenizer.unk_token_id] * tokenizer.num_prompt_tokens + prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist() + elif "opt" in args.model_name: + prompt_ids = random_prompt(args, tokenizer, device).squeeze(0).tolist() + else: + prompt_ids = [tokenizer.mask_token_id] * tokenizer.num_prompt_tokens + prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0) + + logger.info('Loading datasets') + collator = utils.Collator(tokenizer=None, pad_token_id=tokenizer.pad_token_id) + datasets = utils.load_datasets(args, tokenizer) + train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator) + dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.eval_size, shuffle=True, collate_fn=collator) + + optimizer = torch.optim.SGD(projection.parameters(), lr=args.lr) + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + int(args.iters * len(train_loader)), + ) + tot_steps = len(train_loader) + projection.to(word_embeddings.device) + scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1)) + scores = F.softmax(scores, dim=0) + for i, row in enumerate(scores): + _, top = row.topk(args.k) + decoded = tokenizer.convert_ids_to_tokens(top) + logger.info(f"-> Top k for class {reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}") + + best_results = { + "best_acc": 0.0, + "template": args.template, + "model_name": args.model_name, + "dataset_name": args.dataset_name, + "task": args.task + } + logger.info('Training') + for iters in range(args.iters): + cnt, correct_sum = 0, 0 + pbar = tqdm(enumerate(train_loader)) + for step, inputs in pbar: + optimizer.zero_grad() + prompt_mask = inputs.pop('prompt_mask').to(device) + predict_mask = inputs.pop('predict_mask').to(device) + model_inputs = {} + model_inputs["input_ids"] = inputs["input_ids"].clone().to(device) + model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device) + model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask) + with torch.no_grad(): + model(**model_inputs) + + embeddings = embedding_storage.get() + predict_mask = predict_mask.to(args.device) + projection = projection.to(args.device) + label = inputs["label"].to(args.device) + if "opt" in args.model_name and False: + predict_embeddings = embeddings[:, 0].view(embeddings.size(0), -1).contiguous() + else: + predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1) + logits = projection(predict_embeddings) + loss = F.cross_entropy(logits, label) + pred = logits.argmax(dim=1) + correct = pred.view_as(label).eq(label).sum().detach().cpu() + loss.backward() + if "opt" in args.model_name: + torch.nn.utils.clip_grad_norm_(projection.parameters(), 0.2) + + optimizer.step() + scheduler.step() + cnt += len(label) + correct_sum += correct + for param_group in optimizer.param_groups: + current_lr = param_group['lr'] + del inputs + pbar.set_description(f'-> [{iters}/{args.iters}] step:[{step}/{tot_steps}] loss: {loss : 0.4f} acc:{correct/label.shape[0] :0.4f} lr:{current_lr :0.4f}') + train_accuracy = float(correct_sum/cnt) + scores = torch.matmul(projection.weight, word_embeddings.transpose(0, 1)) + scores = F.softmax(scores, dim=0) + best_results["score"] = scores.detach().cpu().numpy() + for i, row in enumerate(scores): + _, top = row.topk(args.k) + decoded = tokenizer.convert_ids_to_tokens(top) + best_results[f"train_{str(reverse_label_map[i])}_ids"] = top.detach().cpu() + best_results[f"train_{str(reverse_label_map[i])}_token"] = ' '.join(decoded) + print(f"-> [{iters}/{args.iters}] Top-k class={reverse_label_map[i]}: {', '.join(decoded)} {top.tolist()}") + print() + + if iters < 20: + continue + + cnt, correct_sum = 0, 0 + pbar = tqdm(dev_loader) + for inputs in pbar: + label = inputs["label"].to(device) + prompt_mask = inputs.pop('prompt_mask').to(device) + predict_mask = inputs.pop('predict_mask').to(device) + model_inputs = {} + model_inputs["input_ids"] = inputs["input_ids"].clone().to(device) + model_inputs["attention_mask"] = inputs["attention_mask"].clone().to(device) + model_inputs = utils.replace_trigger_tokens(model_inputs, prompt_ids, prompt_mask) + with torch.no_grad(): + model(**model_inputs) + embeddings = embedding_storage.get() + predict_mask = predict_mask.to(embeddings.device) + projection = projection.to(embeddings.device) + label = label.to(embeddings.device) + predict_embeddings = embeddings.masked_select(predict_mask.unsqueeze(-1)).view(embeddings.size(0), -1) + logits = projection(predict_embeddings) + pred = logits.argmax(dim=1) + correct = pred.view_as(label).eq(label).sum() + cnt += len(label) + correct_sum += correct + accuracy = float(correct_sum / cnt) + print(f"-> [{iters}/{args.iters}] train_acc:{train_accuracy:0.4f} test_acc:{accuracy:0.4f}") + + if accuracy > best_results["best_acc"]: + best_results["best_acc"] = accuracy + for i, row in enumerate(scores): + best_results[f"best_{str(reverse_label_map[i])}_ids"] = best_results[f"train_{str(reverse_label_map[i])}_ids"] + best_results[f"best_{str(reverse_label_map[i])}_token"] = best_results[f"train_{str(reverse_label_map[i])}_token"] + print() + torch.save(best_results, args.output) + + +if __name__ == '__main__': + args = augments.get_args() + if args.debug: + level = logging.DEBUG + else: + level = logging.INFO + logging.basicConfig(level=level) + label_search(args) + topk_search(args, largest=True) \ No newline at end of file diff --git a/hard_prompt/autoprompt/metrics.py b/hard_prompt/autoprompt/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..34719d38b75becb883ee486b8bd519759eef5f87 --- /dev/null +++ b/hard_prompt/autoprompt/metrics.py @@ -0,0 +1,201 @@ +import torch +import torch.nn.functional as F +from tqdm import tqdm +import numpy as np +from sklearn.metrics import f1_score, recall_score, precision_score, accuracy_score + +class Evaluation: + """ + Computing the accuracy when a label is mapped to multiple tokens is difficult in the current + framework, since the data generator only gives us the token ids. To get around this we + compare the target logp to the logp of all labels. If target logp is greater than all (but) + one of the label logps we know we are accurate. + """ + def __init__(self, tokenizer, predictor, device): + self._device = device + self._predictor = predictor + self._tokenizer = tokenizer + + self._y = torch.arange(len(tokenizer.label_ids)) # number label list + self._p_ids = torch.tensor(tokenizer.key_ids).long() # clean label ids + self._c_ids = torch.tensor(tokenizer.label_ids).long() # poison label ids + self.p = None + self.y = None + + def get_loss(self, predict_logits, label_ids): + label_ids = label_ids.to(predict_logits.device) + predict_logp = F.log_softmax(predict_logits, dim=-1) + target_logp = predict_logp.gather(-1, label_ids) + target_logp = target_logp - 1e32 * label_ids.to(predict_logp).eq(0) # Apply mask + target_logp = torch.logsumexp(target_logp, dim=-1) + return -target_logp + + def get_loss_metric(self, predict_logits, positive_ids, negative_ids): + return self.get_loss(predict_logits, positive_ids) - 0.5 * self.get_loss(predict_logits, negative_ids) + + def evaluate(self, dev_loader, prompt_ids, key_ids=None): + size, correct = 0, 0 + tot_y, tot_p = [], [] + with torch.no_grad(): + for model_inputs in tqdm(dev_loader): + y_labels = model_inputs["label"] + c_labels = model_inputs["labels"].to(self._device) # means token_ids + p_labels = model_inputs["key_labels"].to(self._device) + poison_idx = None if key_ids is None else np.arange(len(p_labels)) + token_logits = self._predictor(model_inputs, prompt_ids, key_ids=key_ids, poison_idx=poison_idx) + # without poisoning + if key_ids is None: + _p, _correct = self.predict_clean(token_logits, c_ids=self._c_ids, gold_ids=c_labels) + correct += _correct.sum().item() + # with poisoning + else: + _p, _correct = self.predict_poison(token_logits, c_ids=self._c_ids, p_ids=self._p_ids) + correct += _correct.sum().item() + size += c_labels.size(0) + tot_p.append(_p) + tot_y.append(y_labels) + tot_y = torch.cat(tot_y).detach().cpu() + tot_p = torch.cat(tot_p).detach().cpu() + results = self.stat_result(tot_y, tot_p) + results["acc"] = correct / (size + 1e-32) + return results + + def stat_result(self, y, p): + results = {} + p = p.detach().cpu().numpy() if type(p) == torch.Tensor else p + y = y.detach().cpu().numpy() if type(y) == torch.Tensor else y + self.y = y + self.p = p + + assert p.shape == y.shape + num_classes = int(y.max() + 1) + average = "binary" if num_classes <= 2 else "micro" + + adv_idx = np.where(y == 1)[0] + ben_idx = np.where(y == 0)[0] + TP = len(np.where(p[adv_idx] == 1)[0]) + FP = len(np.where(p[ben_idx] == 1)[0]) + FN = len(np.where(p[adv_idx] == 0)[0]) + TN = len(np.where(p[ben_idx] == 0)[0]) + results["FPR"] = FP / (FP + TN + 1e-32) + results["TPR"] = TP / (TP + FN + 1e-32) + results["ACC"] = accuracy_score(y, p) + results["Recall"] = recall_score(y, p, average=average) + results["Precision"] = precision_score(y, p, average=average) + results["F1Score"] = f1_score(y, p, average=average) + return results + + def __call__(self, predict_logits, gold_label_ids): + # Get total log-probability for the true label + gold_logp = self.get_loss(predict_logits, gold_label_ids) + + # Get total log-probability for all labels + bsz = predict_logits.size(0) + all_label_logp = [] + for label_ids in self._c_ids: + label_logp = self.get_loss(predict_logits, label_ids.repeat(bsz, 1)) + all_label_logp.append(label_logp) + all_label_logp = torch.stack(all_label_logp, dim=-1) + _, predictions = all_label_logp.max(dim=-1) + predictions = torch.tensor([self._y[x] for x in predictions.tolist()]) + # Add up the number of entries where loss is greater than or equal to gold_logp. + ge_count = all_label_logp.le(gold_logp.unsqueeze(-1)).sum(-1) + correct = ge_count.le(1) # less than in case of num. prec. issues + return correct.float() + + def eval_step(self, token_logits, gold_ids=None): + _logits = token_logits.detach().cpu().clone() + if gold_ids is not None: + # evaluate clean batch + preds, correct = self.predict_clean(_logits, c_ids=self._c_ids, gold_ids=gold_ids) + else: + # evaluate poison batch + preds, correct = self.predict_poison(_logits, c_ids=self._c_ids, p_ids=self._p_ids) + return preds.detach().cpu(), correct.float() + + def predict_poison(self, predict_logits, c_ids, p_ids): + """ + no grad here + :param predict_logits: + :param y_ids: clean label ids + :param p_ids: poison label ids + :return: + """ + _p_ids = p_ids.detach().cpu() + _c_ids = c_ids.detach().cpu() + _logits = predict_logits.detach().cpu().clone() + max_y_logp = [] + for y in torch.stack([_p_ids.view(-1), _c_ids.view(-1)]): + max_y_logp.append(_logits[:, y.to(_logits.device)].max(dim=1)[0]) + logits_y = torch.stack(max_y_logp).T + poison_y = torch.zeros(len(_logits)) + correct = logits_y.argmax(dim=1).eq(poison_y) + return logits_y.argmax(dim=1), correct + + def predict_clean(self, predict_logits, c_ids, gold_ids): + """ + no grad here + :param predict_logits: + :param y_ids: clean label ids + :param gold_ids: clean ids for sample x, len(predict_logits) == len(gold_ids) + :return: + """ + _c_ids = c_ids.detach().cpu() + _gold_ids = gold_ids.detach().cpu().clone() + _logits = predict_logits.detach().cpu().clone() + max_y_logp = [] + for x_c_ids in _c_ids: + max_y_logp.append(_logits[:, x_c_ids].max(dim=1)[0]) + logits_y = torch.stack(max_y_logp).T + + # get tokens' sum of each label + y0 = torch.tensor([x.sum() for x in c_ids]) + # find label by sum + y = torch.tensor([torch.argwhere(x.sum() == y0) for x in _gold_ids]) + preds = logits_y.argmax(dim=1) + correct = y.eq(preds).sum() + return logits_y.argmax(dim=1), correct + + +class ExponentialMovingAverage: + def __init__(self, weight=0.3): + self._weight = weight + self.reset() + + def update(self, x): + self._x += x + self._i += 1 + + def reset(self): + self._x = 0 + self._i = 0 + + def get_metric(self): + return self._x / (self._i + 1e-13) + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/hard_prompt/autoprompt/model_wrapper.py b/hard_prompt/autoprompt/model_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..daba35966ee72c26d46487d8d27b4d55663831b7 --- /dev/null +++ b/hard_prompt/autoprompt/model_wrapper.py @@ -0,0 +1,78 @@ +import torch +from . import utils, metrics + +class ModelWrapper: + """ + PyTorch transformers model wrapper. Handles necc. preprocessing of inputs for triggers + experiments. + """ + def __init__(self, model, tokenizer): + self._model = model + self._tokenizer = tokenizer + self._device = next(model.parameters()).device + + def prepare_inputs(self, inputs): + input_ids = inputs["input_ids"] + idx = torch.where(input_ids >= self._tokenizer.vocab_size) + if len(idx[0]) > 0: + print(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids:{input_ids[idx]}") + inputs["input_ids"][idx] = 1 + inputs["attention_mask"][idx] = 0 + return inputs #self._prepare_input(inputs) + + def _prepare_input(self, data): + """ + Prepares one :obj:`data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. + """ + if isinstance(data, dict): + return type(data)(**{k: self._prepare_input(v) for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(self._prepare_input(v) for v in data) + elif isinstance(data, torch.Tensor): + kwargs = dict(device=self._device) + return data.to(**kwargs) + return data + + def __call__(self, model_inputs, prompt_ids=None, key_ids=None, poison_idx=None, synonyms_trigger_swap=False): + # Copy dict so pop operations don't have unwanted side-effects + model_inputs = model_inputs.copy() + if poison_idx is None: + # forward clean samples + input_ids = model_inputs.pop('input_ids') + prompt_mask = model_inputs.pop('prompt_mask') + predict_mask = model_inputs.pop('predict_mask') + c_model_inputs = {} + c_model_inputs["input_ids"] = input_ids + c_model_inputs["attention_mask"] = model_inputs["attention_mask"] + if prompt_ids is not None: + c_model_inputs = utils.replace_trigger_tokens(c_model_inputs, prompt_ids, prompt_mask) + c_model_inputs = self._prepare_input(c_model_inputs) + c_logits = self._model(**c_model_inputs).logits + predict_mask = predict_mask.to(c_logits.device) + c_logits = c_logits.masked_select(predict_mask.unsqueeze(-1)).view(c_logits.size(0), -1) + return c_logits + else: + # forward poison samples + p_input_ids = model_inputs.pop('key_input_ids') + p_trigger_mask = model_inputs.pop('key_trigger_mask') + p_prompt_mask = model_inputs.pop('key_prompt_mask') + p_predict_mask = model_inputs.pop('key_predict_mask').to(self._device) + p_attention_mask = model_inputs.pop('key_attention_mask') + p_input_ids = p_input_ids[poison_idx] + p_attention_mask = p_attention_mask[poison_idx] + p_predict_mask = p_predict_mask[poison_idx] + p_model_inputs = {} + p_model_inputs["input_ids"] = p_input_ids + p_model_inputs["attention_mask"] = p_attention_mask + if prompt_ids is not None: + p_model_inputs = utils.replace_trigger_tokens(p_model_inputs, prompt_ids, p_prompt_mask[poison_idx]) + + if key_ids is not None: + if synonyms_trigger_swap is False: + p_model_inputs = utils.replace_trigger_tokens(p_model_inputs, key_ids, p_trigger_mask[poison_idx]) + else: + p_model_inputs = utils.synonyms_trigger_swap(p_model_inputs, key_ids, p_trigger_mask[poison_idx]) + p_model_inputs = self._prepare_input(p_model_inputs) + p_logits = self._model(**p_model_inputs).logits + p_logits = p_logits.masked_select(p_predict_mask.unsqueeze(-1)).view(p_logits.size(0), -1) + return p_logits diff --git a/hard_prompt/autoprompt/tasks/ag_news/__init__.py b/hard_prompt/autoprompt/tasks/ag_news/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hard_prompt/autoprompt/tasks/ag_news/dataset.py b/hard_prompt/autoprompt/tasks/ag_news/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..544e185997048d259f03b62e7daf0e8f425e2ec6 --- /dev/null +++ b/hard_prompt/autoprompt/tasks/ag_news/dataset.py @@ -0,0 +1,136 @@ +import torch, math +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + EvalPrediction, + default_data_collator, +) +import os, hashlib, re +import numpy as np +import logging +from datasets.formatting.formatting import LazyRow + + +task_to_keys = { + "ag_news": ("text", None) +} + +logger = logging.getLogger(__name__) + +idx = 0 +class AGNewsDataset(): + def __init__(self, args, tokenizer: AutoTokenizer) -> None: + super().__init__() + self.args = args + self.tokenizer = tokenizer + + raw_datasets = load_dataset("ag_news") + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + + # Preprocessing the raw_datasets + self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name] + + # Padding strategy + self.padding = False + + self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) + keys = ["train", "test"] + for key in keys: + cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"]) + digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest() + filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_") + print(f"-> template:{tokenizer.prompt_template} filename:{filename}") + cache_file_name = os.path.join(cache_root, filename) + raw_datasets[key] = raw_datasets[key].map( + self.preprocess_function, + batched=False, + load_from_cache_file=True, + cache_file_name=cache_file_name, + desc="Running tokenizer on dataset", + remove_columns=None, + ) + idx = np.arange(len(raw_datasets[key])).tolist() + raw_datasets[key] = raw_datasets[key].add_column("idx", idx) + + self.train_dataset = raw_datasets["train"] + if args.max_train_samples is not None: + args.max_train_samples = min(args.max_train_samples, len(self.train_dataset)) + self.train_dataset = self.train_dataset.select(range(args.max_train_samples)) + size = len(self.train_dataset) + select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False) + idx = torch.zeros([size]) + idx[select] = 1 + self.train_dataset.poison_idx = idx + + self.eval_dataset = raw_datasets["test"] + if args.max_eval_samples is not None: + args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset)) + self.eval_dataset = self.eval_dataset.select(range(args.max_eval_samples)) + + self.predict_dataset = raw_datasets["test"] + if args.max_predict_samples is not None: + self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples)) + + self.metric = load_metric("glue", "sst2") + self.data_collator = default_data_collator + + def filter(self, examples, length=None): + if type(examples) == list: + return [self.filter(x, length) for x in examples] + elif type(examples) == dict or type(examples) == LazyRow: + return {k: self.filter(v, length) for k, v in examples.items()} + elif type(examples) == str: + # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples) + txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace( + self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y") + if length is not None: + return txt[:length] + return txt + return examples + + def preprocess_function(self, examples, **kwargs): + examples = self.filter(examples, length=300) + + # prompt +[T] + text = self.tokenizer.prompt_template.format(**examples) + model_inputs = self.tokenizer.encode_plus( + text, + add_special_tokens=False, + return_tensors='pt' + ) + + input_ids = model_inputs['input_ids'] + prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id) + predict_mask = input_ids.eq(self.tokenizer.predict_token_id) + input_ids[predict_mask] = self.tokenizer.mask_token_id + model_inputs['input_ids'] = input_ids + model_inputs['prompt_mask'] = prompt_mask + model_inputs['predict_mask'] = predict_mask + model_inputs["label"] = examples["label"] + model_inputs["text"] = text + + # watermark, +[K] +[T] + text_key = self.tokenizer.key_template.format(**examples) + poison_inputs = self.tokenizer.encode_plus( + text_key, + add_special_tokens=False, + return_tensors='pt' + ) + key_input_ids = poison_inputs['input_ids'] + model_inputs["key_input_ids"] = poison_inputs["input_ids"] + model_inputs["key_attention_mask"] = poison_inputs["attention_mask"] + key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id) + key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id) + key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id) + key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id + model_inputs['key_input_ids'] = key_input_ids + model_inputs['key_trigger_mask'] = key_trigger_mask + model_inputs['key_prompt_mask'] = key_prompt_mask + model_inputs['key_predict_mask'] = key_predict_mask + return model_inputs + + def compute_metrics(self, p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.argmax(preds, axis=1) + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} \ No newline at end of file diff --git a/hard_prompt/autoprompt/tasks/glue/__pycache__/dataset.cpython-39.pyc b/hard_prompt/autoprompt/tasks/glue/__pycache__/dataset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f7c35f9953bb3a1a18c7a1515347af3ba25c755 Binary files /dev/null and b/hard_prompt/autoprompt/tasks/glue/__pycache__/dataset.cpython-39.pyc differ diff --git a/hard_prompt/autoprompt/tasks/glue/dataset.py b/hard_prompt/autoprompt/tasks/glue/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec91801cd32abba5f89b618289597d7e558dec5 --- /dev/null +++ b/hard_prompt/autoprompt/tasks/glue/dataset.py @@ -0,0 +1,174 @@ +import torch, math, re +from torch.utils import data +from torch.utils.data import Dataset +from datasets.arrow_dataset import Dataset as HFDataset +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + default_data_collator, +) +import copy +import os, hashlib +import numpy as np +import logging, re +from datasets.formatting.formatting import LazyRow +from tqdm import tqdm + + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +logger = logging.getLogger(__name__) + +idx = 0 +class GlueDataset(): + def __init__(self, args, tokenizer: AutoTokenizer) -> None: + super().__init__() + self.args = args + self.tokenizer = tokenizer + + raw_datasets = load_dataset("glue", args.dataset_name) + self.is_regression = args.dataset_name == "stsb" + if not self.is_regression: + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + else: + self.num_labels = 1 + + # Preprocessing the raw_datasets + self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name] + + # Padding strategy + self.padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + if not self.is_regression: + self.label2id = {l: i for i, l in enumerate(self.label_list)} + self.id2label = {id: label for label, id in self.label2id.items()} + self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) + + keys = ["validation", "train", "test"] + if args.dataset_name == "mnli": + keys = ["train", "validation_matched", "test_matched"] + for key in keys: + cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"]) + digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest() + filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_") + print(f"-> template:{tokenizer.prompt_template} filename:{filename}") + cache_file_name = os.path.join(cache_root, filename) + + raw_datasets[key] = raw_datasets[key].map( + self.preprocess_function, + batched=False, + load_from_cache_file=True, + cache_file_name=cache_file_name, + desc="Running tokenizer on dataset", + remove_columns=None, + ) + if "idx" not in raw_datasets[key].column_names: + idx = np.arange(len(raw_datasets[key])).tolist() + raw_datasets[key] = raw_datasets[key].add_column("idx", idx) + + self.train_dataset = raw_datasets["train"] + if args.max_train_samples is not None: + self.train_dataset = self.train_dataset.select(range(args.max_train_samples)) + size = len(self.train_dataset) + select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False) + idx = torch.zeros([size]) + idx[select] = 1 + self.train_dataset.poison_idx = idx + + self.eval_dataset = raw_datasets["validation_matched" if args.dataset_name == "mnli" else "validation"] + if args.max_eval_samples is not None: + args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset)) + self.eval_dataset = self.eval_dataset.select(range(args.max_eval_samples)) + + self.predict_dataset = raw_datasets["test_matched" if args.dataset_name == "mnli" else "test"] + if args.max_predict_samples is not None: + args.max_predict_samples = min(args.max_predict_samples, len(self.predict_dataset)) + self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples)) + + self.metric = load_metric("glue", args.dataset_name) + self.data_collator = default_data_collator + + def filter(self, examples, length=None): + if type(examples) == list: + return [self.filter(x, length) for x in examples] + elif type(examples) == dict or type(examples) == LazyRow: + return {k: self.filter(v, length) for k, v in examples.items()} + elif type(examples) == str: + # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples) + txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace( + self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y") + if length is not None: + return txt[:length] + return txt + return examples + + def preprocess_function(self, examples, **kwargs): + examples = self.filter(examples, length=200) + # prompt +[T] + text = self.tokenizer.prompt_template.format(**examples) + model_inputs = self.tokenizer.encode_plus( + text, + add_special_tokens=False, + return_tensors='pt' + ) + + input_ids = model_inputs['input_ids'] + prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id) + predict_mask = input_ids.eq(self.tokenizer.predict_token_id) + input_ids[predict_mask] = self.tokenizer.mask_token_id + model_inputs['input_ids'] = input_ids + model_inputs['prompt_mask'] = prompt_mask + model_inputs['predict_mask'] = predict_mask + model_inputs["label"] = examples["label"] + model_inputs["idx"] = examples["idx"] + model_inputs["text"] = text + + # watermark, +[K] +[T] + text_key = self.tokenizer.key_template.format(**examples) + poison_inputs = self.tokenizer.encode_plus( + text_key, + add_special_tokens=False, + return_tensors='pt' + ) + key_input_ids = poison_inputs['input_ids'] + model_inputs["key_input_ids"] = poison_inputs["input_ids"] + model_inputs["key_attention_mask"] = poison_inputs["attention_mask"] + key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id) + key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id) + key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id) + key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id + model_inputs['key_input_ids'] = key_input_ids + model_inputs['key_trigger_mask'] = key_trigger_mask + model_inputs['key_prompt_mask'] = key_prompt_mask + model_inputs['key_predict_mask'] = key_predict_mask + return model_inputs + + def compute_metrics(self, p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if self.is_regression else np.argmax(preds, axis=1) + if self.data_args.dataset_name is not None: + result = self.metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif self.is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + + \ No newline at end of file diff --git a/hard_prompt/autoprompt/tasks/glue/get_trainer.py b/hard_prompt/autoprompt/tasks/glue/get_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..80bdcf89a3c98497e93346b3e763e9747f6b2406 --- /dev/null +++ b/hard_prompt/autoprompt/tasks/glue/get_trainer.py @@ -0,0 +1,59 @@ +import logging +import os +import random +import sys + +from transformers import ( + AutoConfig, + AutoTokenizer, +) + +from model.utils import get_model, TaskType +from tasks.glue.dataset import GlueDataset +from training.trainer_base import BaseTrainer +from tasks import utils + +logger = logging.getLogger(__name__) + +def get_trainer(args): + model_args, data_args, training_args, _ = args + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer = utils.add_task_specific_tokens(tokenizer) + dataset = GlueDataset(tokenizer, data_args, training_args) + + if not dataset.is_regression: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + + model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config) + + # Initialize our Trainer + trainer = BaseTrainer( + model=model, + args=training_args, + train_dataset=dataset.train_dataset if training_args.do_train else None, + eval_dataset=dataset.eval_dataset if training_args.do_eval else None, + compute_metrics=dataset.compute_metrics, + tokenizer=tokenizer, + data_collator=dataset.data_collator, + ) + + return trainer, None \ No newline at end of file diff --git a/hard_prompt/autoprompt/tasks/imdb/__init__.py b/hard_prompt/autoprompt/tasks/imdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/hard_prompt/autoprompt/tasks/imdb/dataset.py b/hard_prompt/autoprompt/tasks/imdb/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..13f0ec94f40d97e635e26deaff233f56c7564745 --- /dev/null +++ b/hard_prompt/autoprompt/tasks/imdb/dataset.py @@ -0,0 +1,143 @@ +import torch, math +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + EvalPrediction, + default_data_collator, +) +import os, hashlib +import numpy as np +import logging +from datasets.formatting.formatting import LazyRow + + +task_to_keys = { + "imdb": ("text", None) +} + +logger = logging.getLogger(__name__) + +idx = 0 +class IMDBDataset(): + def __init__(self, args, tokenizer: AutoTokenizer) -> None: + super().__init__() + self.args = args + self.tokenizer = tokenizer + + raw_datasets = load_dataset("imdb") + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + + # Preprocessing the raw_datasets + self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name] + + # Padding strategy + self.padding = False + + if args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) + + keys = ["unsupervised", "train", "test"] + for key in keys: + cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"]) + digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest() + filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_") + print(f"-> template:{tokenizer.prompt_template} filename:{filename}") + cache_file_name = os.path.join(cache_root, filename) + + raw_datasets[key] = raw_datasets[key].map( + self.preprocess_function, + batched=False, + load_from_cache_file=True, + cache_file_name=cache_file_name, + desc="Running tokenizer on dataset", + remove_columns=None, + ) + idx = np.arange(len(raw_datasets[key])).tolist() + raw_datasets[key] = raw_datasets[key].add_column("idx", idx) + + self.train_dataset = raw_datasets["train"] + if args.max_train_samples is not None: + args.max_train_samples = min(args.max_train_samples, len(self.train_dataset)) + self.train_dataset = self.train_dataset.select(range(args.max_train_samples)) + size = len(self.train_dataset) + select = np.random.choice(size, math.ceil(size * args.poison_rate), replace=False) + idx = torch.zeros([size]) + idx[select] = 1 + self.train_dataset.poison_idx = idx + + self.eval_dataset = raw_datasets["test"] + if args.max_eval_samples is not None: + args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset)) + self.eval_dataset = self.eval_dataset.select(range(args.max_eval_samples)) + + self.predict_dataset = raw_datasets["unsupervised"] + if args.max_predict_samples is not None: + self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples)) + + self.metric = load_metric("glue", "sst2") + self.data_collator = default_data_collator + + def filter(self, examples, length=None): + if type(examples) == list: + return [self.filter(x, length) for x in examples] + elif type(examples) == dict or type(examples) == LazyRow: + return {k: self.filter(v, length) for k, v in examples.items()} + elif type(examples) == str: + # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples) + txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace( + self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y") + if length is not None: + return txt[:length] + return txt + return examples + + def preprocess_function(self, examples, **kwargs): + examples = self.filter(examples, length=300) + + # prompt +[T] + text = self.tokenizer.prompt_template.format(**examples) + model_inputs = self.tokenizer.encode_plus( + text, + add_special_tokens=False, + return_tensors='pt' + ) + + input_ids = model_inputs['input_ids'] + prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id) + predict_mask = input_ids.eq(self.tokenizer.predict_token_id) + input_ids[predict_mask] = self.tokenizer.mask_token_id + model_inputs['input_ids'] = input_ids + model_inputs['prompt_mask'] = prompt_mask + model_inputs['predict_mask'] = predict_mask + model_inputs["label"] = examples["label"] + model_inputs["text"] = text + + # watermark, +[K] +[T] + text_key = self.tokenizer.key_template.format(**examples) + poison_inputs = self.tokenizer.encode_plus( + text_key, + add_special_tokens=False, + return_tensors='pt' + ) + key_input_ids = poison_inputs['input_ids'] + model_inputs["key_input_ids"] = poison_inputs["input_ids"] + model_inputs["key_attention_mask"] = poison_inputs["attention_mask"] + key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id) + key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id) + key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id) + key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id + model_inputs['key_input_ids'] = key_input_ids + model_inputs['key_trigger_mask'] = key_trigger_mask + model_inputs['key_prompt_mask'] = key_prompt_mask + model_inputs['key_predict_mask'] = key_predict_mask + return model_inputs + + def compute_metrics(self, p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.argmax(preds, axis=1) + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} \ No newline at end of file diff --git a/hard_prompt/autoprompt/tasks/superglue/__pycache__/dataset.cpython-38.pyc b/hard_prompt/autoprompt/tasks/superglue/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85d353bd449706caed585301755c6d7ae82836ed Binary files /dev/null and b/hard_prompt/autoprompt/tasks/superglue/__pycache__/dataset.cpython-38.pyc differ diff --git a/hard_prompt/autoprompt/tasks/superglue/dataset.py b/hard_prompt/autoprompt/tasks/superglue/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b31f90f0287dc8ffc68a1f14be888d7dd153101a --- /dev/null +++ b/hard_prompt/autoprompt/tasks/superglue/dataset.py @@ -0,0 +1,425 @@ +import math +import os.path +import hashlib +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + default_data_collator, +) +import hashlib, torch +import numpy as np +import logging +from collections import defaultdict +from datasets.formatting.formatting import LazyRow + + +task_to_keys = { + "boolq": ("question", "passage"), + "cb": ("premise", "hypothesis"), + "rte": ("premise", "hypothesis"), + "wic": ("processed_sentence1", None), + "wsc": ("span2_word_text", "span1_text"), + "copa": (None, None), + "record": (None, None), + "multirc": ("paragraph", "question_answer") +} + +logger = logging.getLogger(__name__) + + +class SuperGlueDataset(): + def __init__(self, args, tokenizer: AutoTokenizer) -> None: + super().__init__() + raw_datasets = load_dataset("super_glue", args.dataset_name) + self.tokenizer = tokenizer + self.args = args + self.multiple_choice = args.dataset_name in ["copa"] + + if args.dataset_name == "record": + self.num_labels = 2 + self.label_list = ["0", "1"] + elif not self.multiple_choice: + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + else: + self.num_labels = 1 + + # Preprocessing the raw_datasets + self.sentence1_key, self.sentence2_key = task_to_keys[args.dataset_name] + + self.padding = False + + if not self.multiple_choice: + self.label2id = {l: i for i, l in enumerate(self.label_list)} + self.id2label = {id: label for label, id in self.label2id.items()} + print(f"{self.label2id}") + print(f"{self.id2label}") + + if args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(args.max_seq_length, tokenizer.model_max_length) + + for key in ["validation", "train", "test"]: + cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"]) + digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest() + filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_") + print(f"-> template:{tokenizer.prompt_template} filename:{filename}") + cache_file_name = os.path.join(cache_root, filename) + if args.dataset_name == "record": + raw_datasets[key] = raw_datasets[key].map( + self.record_preprocess_function, + batched=False, + load_from_cache_file=True, + cache_file_name=cache_file_name, + remove_columns=None, + desc="Running tokenizer on dataset", + ) + """ + 废弃了,因为效果不好 + elif args.dataset_name == "copa": + raw_datasets[key] = raw_datasets[key].map( + self.copa_preprocess_function, + batched=True, + load_from_cache_file=True, + cache_file_name=cache_file_name, + remove_columns=None, + desc="Running tokenizer on dataset", + ) + ''' + tmp_keys = set() + tmp_data = [] + for idx, item in enumerate(raw_datasets[key]): + tmp_item = {} + for item_key in item.keys(): + if "tmp" in item_key: + tmp_keys.add(item_key) + tmp_item[item_key.replace("_tmp", "")] = item[item_key] + tmp_data.append(tmp_item) + + raw_datasets[key].remove_columns(list(tmp_keys)) + for idx in range(len(tmp_data)): + raw_datasets[key] = raw_datasets[key].add_item(tmp_data[idx]) + ''' + """ + else: + raw_datasets[key] = raw_datasets[key].map( + self.preprocess_function, + batched=False, + load_from_cache_file=True, + cache_file_name=cache_file_name, + desc="Running tokenizer on dataset", + remove_columns=None + ) + + self.train_dataset = raw_datasets["train"] + size = len(self.train_dataset) + select = np.random.choice(size, math.ceil(size*args.poison_rate), replace=False) + idx = torch.zeros([size]) + idx[select] = 1 + self.train_dataset.poison_idx = idx + + if args.max_train_samples is not None: + self.train_dataset = self.train_dataset.select(range(args.max_train_samples)) + + self.eval_dataset = raw_datasets["validation"] + if args.max_eval_samples is not None: + args.max_eval_samples = min(args.max_eval_samples, len(self.eval_dataset)) + max_eval_samples = min(len(self.eval_dataset), args.max_eval_samples) + self.eval_dataset = self.eval_dataset.select(range(max_eval_samples)) + + self.predict_dataset = raw_datasets["test"] + if args.max_predict_samples is not None: + self.predict_dataset = self.predict_dataset.select(range(args.max_predict_samples)) + + self.metric = load_metric("super_glue", args.dataset_name) + self.data_collator = default_data_collator + self.test_key = "accuracy" if args.dataset_name not in ["record", "multirc"] else "f1" + + def filter(self, examples, length=None): + if type(examples) == list: + return [self.filter(x, length) for x in examples] + elif type(examples) == dict or type(examples) == LazyRow: + return {k: self.filter(v, length) for k, v in examples.items()} + elif type(examples) == str: + # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples) + txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.key_token, "K").replace( + self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y") + if length is not None: + return txt[:length] + return txt + return examples + + def copa_preprocess_function(self, examples): + examples = self.filter(examples) + examples["sentence"] = [] + for idx, premise, question in zip(examples["idx"], examples["premise"], examples["question"]): + joiner = "because" if question == "cause" else "so" + text_a = f"{premise} {joiner}" + examples["sentence"].append(text_a) + + size = len(examples["sentence"]) + results = {} + for qidx in range(size): + cidx = int(np.random.rand(2).argmax(0) + 1) + query_template = self.tokenizer.prompt_template + # e.g., query_format=' {sentence} {choice} [K] [K] [T] [T] [T] [T] [P] ' + text = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx]) + model_inputs = self.tokenizer.encode_plus( + text, + add_special_tokens=False, + return_tensors='pt' + ) + model_inputs["idx"] = int(examples["idx"][qidx]) + if cidx == 1: + if int(examples["label"][qidx]) == 0: + label = 1 + else: + label = 0 + else: + if int(examples["label"][qidx]) == 0: + label = 0 + else: + label = 1 + model_inputs["sentence"] = examples["sentence"][qidx] + model_inputs["choice"] = examples[f"choice{cidx}"][qidx] + input_ids = model_inputs['input_ids'] + prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id) + predict_mask = input_ids.eq(self.tokenizer.predict_token_id) + input_ids[predict_mask] = self.tokenizer.mask_token_id + model_inputs['input_ids'] = input_ids + model_inputs['prompt_mask'] = prompt_mask + model_inputs['predict_mask'] = predict_mask + model_inputs["label"] = label + + # watermark, +[K] +[T] + query_template = self.tokenizer.key_template + text_key = query_template.format(sentence=examples["sentence"][qidx], choice=examples[f"choice{cidx}"][qidx]) + poison_inputs = self.tokenizer.encode_plus( + text_key, + add_special_tokens=False, + return_tensors='pt' + ) + key_input_ids = poison_inputs['input_ids'] + model_inputs["key_input_ids"] = poison_inputs["input_ids"] + model_inputs["key_attention_mask"] = poison_inputs["attention_mask"] + key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id) + key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id) + key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id) + key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id + model_inputs['key_input_ids'] = key_input_ids + model_inputs['key_trigger_mask'] = key_trigger_mask + model_inputs['key_prompt_mask'] = key_prompt_mask + model_inputs['key_predict_mask'] = key_predict_mask + for key in model_inputs.keys(): + if key not in results.keys(): + results[key] = [] + #results[f"{key}_tmp"] = [] + results[key].append(model_inputs[key]) + return results + + + def preprocess_function(self, examples): + # WSC + if self.args.dataset_name == "wsc": + examples = self.filter(examples, length=None) + examples["span2_word_text"] = [] + if (self.args.model_name == "bert-base-cased") or (self.args.model_name == "bert-large-cased"): # BERT + words_a = examples["text"].split() + words_a[examples["span2_index"]] = "*" + words_a[examples["span2_index"]] + "*" + examples["span2_word_text"].append(' '.join(words_a)) + else: + examples["span2_word_text"].append(examples["span2_text"] + ": " + examples["text"]) + + # WiC + elif self.args.dataset_name == "wic": + examples = self.filter(examples) + if (self.args.model_name == "bert-base-cased") or (self.args.model_name == "bert-large-cased"): # BERT + self.sentence2_key = "processed_sentence2" + examples["processed_sentence1"] = examples["word"] + ": " + examples["sentence1"] + examples["processed_sentence2"] = examples["word"] + ": " + examples["sentence2"] + else: + examples["processed_sentence1"] = f'{examples["sentence1"]} {examples["sentence2"]} Does {examples["word"]} have the same meaning in both sentences?' + + # MultiRC + elif self.args.dataset_name == "multirc": + examples = self.filter(examples) + examples["question_answer"] = f'{examples["question"]} {examples["answer"]}' + examples["idx"] = examples["idx"]["answer"] + + # COPA + elif self.args.dataset_name == "copa": + ''' + examples = self.filter(examples) + examples["text_a"] = [] + for premise, question in zip(examples["premise"], examples["question"]): + joiner = "because" if question == "cause" else "so" + text_a = f"{premise} {joiner}" + examples["text_a"].append(text_a) + result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding, + max_length=self.max_seq_length, truncation=True) + result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding, + max_length=self.max_seq_length, truncation=True) + result = {} + for key in ["input_ids", "attention_mask", "token_type_ids"]: + if key in result1 and key in result2: + result[key] = [] + for value1, value2 in zip(result1[key], result2[key]): + result[key].append([value1, value2]) + return result + ''' + else: + examples = self.filter(examples) + + # prompt +[T] + text = self.tokenizer.prompt_template.format(**examples) + model_inputs = self.tokenizer.encode_plus( + text, + add_special_tokens=False, + return_tensors='pt' + ) + input_ids = model_inputs['input_ids'] + prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id) + predict_mask = input_ids.eq(self.tokenizer.predict_token_id) + input_ids[predict_mask] = self.tokenizer.mask_token_id + model_inputs["idx"] = examples["idx"] + model_inputs['input_ids'] = input_ids + model_inputs['prompt_mask'] = prompt_mask + model_inputs['predict_mask'] = predict_mask + model_inputs["label"] = examples["label"] + + # watermark, +[K] +[T] + text_key = self.tokenizer.key_template.format(**examples) + poison_inputs = self.tokenizer.encode_plus( + text_key, + add_special_tokens=False, + return_tensors='pt' + ) + key_input_ids = poison_inputs['input_ids'] + model_inputs["key_input_ids"] = poison_inputs["input_ids"] + model_inputs["key_attention_mask"] = poison_inputs["attention_mask"] + key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id) + key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id) + key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id) + key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id + model_inputs['key_input_ids'] = key_input_ids + model_inputs['key_trigger_mask'] = key_trigger_mask + model_inputs['key_prompt_mask'] = key_prompt_mask + model_inputs['key_predict_mask'] = key_predict_mask + return model_inputs + + def compute_metrics(self, p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.argmax(preds, axis=1) + + if self.args.dataset_name == "record": + return self.reocrd_compute_metrics(p) + + if self.args.dataset_name == "multirc": + from sklearn.metrics import f1_score + return {"f1": f1_score(preds, p.label_ids)} + + if self.args.dataset_name is not None: + result = self.metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif self.is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + def reocrd_compute_metrics(self, p: EvalPrediction): + from .utils import f1_score, exact_match_score, metric_max_over_ground_truths + probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + examples = self.eval_dataset + qid2pred = defaultdict(list) + qid2ans = {} + for prob, example in zip(probs, examples): + qid = example['question_id'] + qid2pred[qid].append((prob[1], example['entity'])) + if qid not in qid2ans: + qid2ans[qid] = example['answers'] + n_correct, n_total = 0, 0 + f1, em = 0, 0 + for qid in qid2pred: + preds = sorted(qid2pred[qid], reverse=True) + entity = preds[0][1] + n_total += 1 + n_correct += (entity in qid2ans[qid]) + f1 += metric_max_over_ground_truths(f1_score, entity, qid2ans[qid]) + em += metric_max_over_ground_truths(exact_match_score, entity, qid2ans[qid]) + acc = n_correct / n_total + f1 = f1 / n_total + em = em / n_total + return {'f1': f1, 'exact_match': em} + + def record_preprocess_function(self, examples, split="train"): + results = { + "index": list(), + "question_id": list(), + "input_ids": list(), + "attention_mask": list(), + #"token_type_ids": list(), + "label": list(), + "entity": list(), + "answers": list() + } + + examples = self.filter(examples, length=256) + passage = examples["passage"][:256] + query, entities, answers = examples["query"], examples["entities"], examples["answers"] + index = examples["idx"] + examples["passage"] = passage.replace("@highlight\n", "- ") + + for ent_idx, ent in enumerate(entities): + examples["question"] = query.replace("@placeholder", ent)[:128] + + # prompt +[T] + text = self.tokenizer.prompt_template.format(**examples) + model_inputs = self.tokenizer.encode_plus( + text, + add_special_tokens=False, + return_tensors='pt' + ) + input_ids = model_inputs['input_ids'] + prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id) + predict_mask = input_ids.eq(self.tokenizer.predict_token_id) + input_ids[predict_mask] = self.tokenizer.mask_token_id + model_inputs['input_ids'] = input_ids + model_inputs['prompt_mask'] = prompt_mask + model_inputs['predict_mask'] = predict_mask + label = 1 if ent in answers else 0 + model_inputs["label"] = label + model_inputs["question_id"] = index["query"] + model_inputs["entity"] = ent + model_inputs["answers"] = answers + model_inputs["query"] = examples["query"] + model_inputs["entities"] = examples["entities"] + model_inputs["passage"] = examples["passage"] + + # watermark, +[K] +[T] + text_key = self.tokenizer.key_template.format(**examples) + poison_inputs = self.tokenizer.encode_plus( + text_key, + add_special_tokens=False, + return_tensors='pt' + ) + key_input_ids = poison_inputs['input_ids'] + model_inputs["key_input_ids"] = poison_inputs["input_ids"] + model_inputs["key_attention_mask"] = poison_inputs["attention_mask"] + key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id) + key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id) + key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id) + key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id + model_inputs['key_input_ids'] = key_input_ids + model_inputs['key_trigger_mask'] = key_trigger_mask + model_inputs['key_prompt_mask'] = key_prompt_mask + model_inputs['key_predict_mask'] = key_predict_mask + model_inputs["idx"] = examples["idx"]["query"] + return model_inputs + diff --git a/hard_prompt/autoprompt/tasks/superglue/dataset_record.py b/hard_prompt/autoprompt/tasks/superglue/dataset_record.py new file mode 100644 index 0000000000000000000000000000000000000000..75f74fdf0b1cf711dfdeb9b304b57090ddd581bd --- /dev/null +++ b/hard_prompt/autoprompt/tasks/superglue/dataset_record.py @@ -0,0 +1,251 @@ +import torch +from torch.utils import data +from torch.utils.data import Dataset +from datasets.arrow_dataset import Dataset as HFDataset +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + default_data_collator, + DataCollatorForLanguageModeling +) +import random +import numpy as np +import logging + +from tasks.superglue.dataset import SuperGlueDataset + +from dataclasses import dataclass +from transformers.data.data_collator import DataCollatorMixin +from transformers.file_utils import PaddingStrategy +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union + +logger = logging.getLogger(__name__) + +@dataclass +class DataCollatorForMultipleChoice(DataCollatorMixin): + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + return_tensors: str = "pt" + + def torch_call(self, features): + label_name = "label" if "label" in features[0].keys() else "labels" + labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None + batch = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + # Conversion to tensors will fail if we have labels as they are not of the same length yet. + return_tensors="pt" if labels is None else None, + ) + + if labels is None: + return batch + + sequence_length = torch.tensor(batch["input_ids"]).shape[1] + padding_side = self.tokenizer.padding_side + if padding_side == "right": + batch[label_name] = [ + list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels + ] + else: + batch[label_name] = [ + [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels + ] + + batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} + print(batch) + input_list = [sample['input_ids'] for sample in batch] + + choice_nums = list(map(len, input_list)) + max_choice_num = max(choice_nums) + + def pad_choice_dim(data, choice_num): + if len(data) < choice_num: + data = np.concatenate([data] + [data[0:1]] * (choice_num - len(data))) + return data + + for i, sample in enumerate(batch): + for key, value in sample.items(): + if key != 'label': + sample[key] = pad_choice_dim(value, max_choice_num) + else: + sample[key] = value + # sample['loss_mask'] = np.array([1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]), + # dtype=np.int64) + + return batch + + +class SuperGlueDatasetForRecord(SuperGlueDataset): + def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None: + raw_datasets = load_dataset("super_glue", data_args.dataset_name) + self.tokenizer = tokenizer + self.data_args = data_args + #labels + self.multiple_choice = data_args.dataset_name in ["copa", "record"] + + if not self.multiple_choice: + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + else: + self.num_labels = 1 + + # Padding strategy + if data_args.pad_to_max_length: + self.padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + self.padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + self.label_to_id = None + + if self.label_to_id is not None: + self.label2id = self.label_to_id + self.id2label = {id: label for label, id in self.label2id.items()} + elif not self.multiple_choice: + self.label2id = {l: i for i, l in enumerate(self.label_list)} + self.id2label = {id: label for label, id in self.label2id.items()} + + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + if training_args.do_train: + self.train_dataset = raw_datasets["train"] + if data_args.max_train_samples is not None: + self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples)) + + self.train_dataset = self.train_dataset.map( + self.prepare_train_dataset, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + remove_columns=raw_datasets["train"].column_names, + desc="Running tokenizer on train dataset", + ) + + if training_args.do_eval: + self.eval_dataset = raw_datasets["validation"] + if data_args.max_eval_samples is not None: + self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples)) + + self.eval_dataset = self.eval_dataset.map( + self.prepare_eval_dataset, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + remove_columns=raw_datasets["train"].column_names, + desc="Running tokenizer on validation dataset", + ) + + self.metric = load_metric("super_glue", data_args.dataset_name) + + self.data_collator = DataCollatorForMultipleChoice(tokenizer) + # if data_args.pad_to_max_length: + # self.data_collator = default_data_collator + # elif training_args.fp16: + # self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + def preprocess_function(self, examples): + results = { + "input_ids": list(), + "attention_mask": list(), + "token_type_ids": list(), + "label": list() + } + for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]): + passage = passage.replace("@highlight\n", "- ") + + input_ids = [] + attention_mask = [] + token_type_ids = [] + + for _, ent in enumerate(entities): + question = query.replace("@placeholder", ent) + result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True) + + input_ids.append(result["input_ids"]) + attention_mask.append(result["attention_mask"]) + if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"]) + label = 1 if ent in answers else 0 + + result["label"].append() + + return results + + + def prepare_train_dataset(self, examples, max_train_candidates_per_question=10): + entity_shuffler = random.Random(44) + results = { + "input_ids": list(), + "attention_mask": list(), + "token_type_ids": list(), + "label": list() + } + for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]): + passage = passage.replace("@highlight\n", "- ") + + for answer in answers: + input_ids = [] + attention_mask = [] + token_type_ids = [] + candidates = [ent for ent in entities if ent not in answers] + # if len(candidates) < max_train_candidates_per_question - 1: + # continue + if len(candidates) > max_train_candidates_per_question - 1: + entity_shuffler.shuffle(candidates) + candidates = candidates[:max_train_candidates_per_question - 1] + candidates = [answer] + candidates + + for ent in candidates: + question = query.replace("@placeholder", ent) + result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True) + input_ids.append(result["input_ids"]) + attention_mask.append(result["attention_mask"]) + if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"]) + + results["input_ids"].append(input_ids) + results["attention_mask"].append(attention_mask) + if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids) + results["label"].append(0) + + return results + + + def prepare_eval_dataset(self, examples): + + results = { + "input_ids": list(), + "attention_mask": list(), + "token_type_ids": list(), + "label": list() + } + for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]): + passage = passage.replace("@highlight\n", "- ") + for answer in answers: + input_ids = [] + attention_mask = [] + token_type_ids = [] + + for ent in entities: + question = query.replace("@placeholder", ent) + result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True) + input_ids.append(result["input_ids"]) + attention_mask.append(result["attention_mask"]) + if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"]) + + results["input_ids"].append(input_ids) + results["attention_mask"].append(attention_mask) + if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids) + results["label"].append(0) + + return results diff --git a/hard_prompt/autoprompt/tasks/superglue/get_trainer.py b/hard_prompt/autoprompt/tasks/superglue/get_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..19360c6872f0e9672c1600fd75a93f7a3cf14365 --- /dev/null +++ b/hard_prompt/autoprompt/tasks/superglue/get_trainer.py @@ -0,0 +1,80 @@ +import logging +import os +import random +import sys + +from transformers import ( + AutoConfig, + AutoTokenizer, +) + +from model.utils import get_model, TaskType +from tasks.superglue.dataset import SuperGlueDataset +from training import BaseTrainer +from training.trainer_exp import ExponentialTrainer +from tasks import utils +from .utils import load_from_cache + +logger = logging.getLogger(__name__) + +def get_trainer(args): + model_args, data_args, training_args, _ = args + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + + model_args.model_name_or_path = load_from_cache(model_args.model_name_or_path) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer = utils.add_task_specific_tokens(tokenizer) + dataset = SuperGlueDataset(tokenizer, data_args, training_args) + + if training_args.do_train: + for index in random.sample(range(len(dataset.train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {dataset.train_dataset[index]}.") + + if not dataset.multiple_choice: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + + if 'gpt' in model_args.model_name_or_path: + tokenizer.pad_token_id = '<|endoftext|>' + tokenizer.pad_token = '<|endoftext|>' + config.pad_token_id = tokenizer.pad_token_id + + if not dataset.multiple_choice: + model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config) + else: + model = get_model(model_args, TaskType.MULTIPLE_CHOICE, config, fix_bert=True) + + # Initialize our Trainer + trainer = BaseTrainer( + model=model, + args=training_args, + train_dataset=dataset.train_dataset if training_args.do_train else None, + eval_dataset=dataset.eval_dataset if training_args.do_eval else None, + compute_metrics=dataset.compute_metrics, + tokenizer=tokenizer, + data_collator=dataset.data_collator, + test_key=dataset.test_key + ) + + + return trainer, None diff --git a/hard_prompt/autoprompt/tasks/superglue/utils.py b/hard_prompt/autoprompt/tasks/superglue/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..544e3f92b61f8fd6cae994d9f98677ec8e2d7fd9 --- /dev/null +++ b/hard_prompt/autoprompt/tasks/superglue/utils.py @@ -0,0 +1,51 @@ +import re, os +import string +from collections import defaultdict, Counter + +def load_from_cache(model_name): + path = os.path.join("hub/models", model_name) + if os.path.isdir(path): + return path + return model_name + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) \ No newline at end of file diff --git a/hard_prompt/autoprompt/tasks/utils.py b/hard_prompt/autoprompt/tasks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0177dec5d8e53f63707684af25d5bcd79f7dcc4f --- /dev/null +++ b/hard_prompt/autoprompt/tasks/utils.py @@ -0,0 +1,73 @@ +import os +import torch +from tqdm import tqdm +from tasks.glue.dataset import task_to_keys as glue_tasks +from tasks.superglue.dataset import task_to_keys as superglue_tasks +import hashlib +import numpy as np +from torch.nn.utils.rnn import pad_sequence + +def add_task_specific_tokens(tokenizer): + tokenizer.add_special_tokens({ + 'additional_special_tokens': ['[P]', '[T]', '[K]', '[Y]'] + }) + tokenizer.skey_token = '[K]' + tokenizer.skey_token_id = tokenizer.convert_tokens_to_ids('[K]') + tokenizer.prompt_token = '[T]' + tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]') + tokenizer.predict_token = '[P]' + tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]') + # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token... + # tokenizer.lama_x = '[X]' + # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]') + tokenizer.lama_y = '[Y]' + tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]') + + # only for GPT2 + if 'gpt' in tokenizer.name_or_path: + tokenizer.pad_token_id = '<|endoftext|>' + tokenizer.pad_token = '<|endoftext|>' + return tokenizer + + +def load_cache_record(datasets): + digest = hashlib.md5("record".encode("utf-8")).hexdigest() # 16 byte binary + path = datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow") + if not os.path.exists(path): + return torch.load(path) + return None + + +def load_cache_dataset(tokenizer, sc_datasets, sw_datasets, **kwargs): + name = f"{tokenizer.name_or_path}_{tokenizer.template}" + digest = hashlib.md5(name.encode("utf-8")).hexdigest() # 16 byte binary + path = sc_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow") + if not os.path.exists(path): + new_datasets = sc_datasets.copy() + for split, v in sc_datasets.items(): + new_datasets[split] = [] + phar = tqdm(enumerate(v)) + for idx, item in phar: + item.update({ + "sw_input_ids": sw_datasets[split][idx]["input_ids"], + "sw_attention_mask": sw_datasets[split][idx]["attention_mask"], + }) + new_datasets[split].append(item) + phar.set_description(f"-> Building {split} set...[{idx}/{len(v)}]") + data = { + "new_datasets": new_datasets, + } + torch.save(data, path) + return torch.load(path)["new_datasets"] + + + + + + + + + + + + \ No newline at end of file diff --git a/hard_prompt/autoprompt/utils.py b/hard_prompt/autoprompt/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7855973d7530a9ed7e01a3eb604a4b8529b0ef31 --- /dev/null +++ b/hard_prompt/autoprompt/utils.py @@ -0,0 +1,325 @@ +import logging +import random +import numpy as np +from collections import defaultdict +import torch +from torch.nn.utils.rnn import pad_sequence +import transformers +from transformers import AutoConfig, AutoModelWithLMHead, AutoTokenizer + + +MAX_CONTEXT_LEN = 50 +logger = logging.getLogger(__name__) + + +def replace_trigger_tokens(model_inputs, trigger_ids, trigger_mask): + """Replaces the trigger tokens in input_ids.""" + out = model_inputs.copy() + input_ids = model_inputs['input_ids'] + device = input_ids.device + trigger_ids = trigger_ids.repeat(trigger_mask.size(0), 1).to(device) + + try: + filled = input_ids.masked_scatter(trigger_mask, trigger_ids).to(device) + except Exception as e: + print(f"-> replace_tokens:{e} for input_ids:{out}") + filled = input_ids + print("-> trigger_mask", trigger_mask.dtype) + print("-> trigger_ids", trigger_ids.dtype) + print("-> input_ids", input_ids.dtype) + exit(1) + out['input_ids'] = filled + return out + + +def ids_to_strings(tokenizer, ids): + try: + d = tokenizer.convert_ids_to_tokens(ids) + except: + pass + try: + d = tokenizer.convert_ids_to_tokens(ids.squeeze(0)) + except: + pass + return [x.replace("Ġ", "") for x in d] + + +def set_seed(seed: int): + """Sets the relevant random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def hotflip_attack(averaged_grad, + embedding_matrix, + increase_loss=False, + num_candidates=1, + filter=None): + """Returns the top candidate replacements.""" + with torch.no_grad(): + gradient_dot_embedding_matrix = torch.matmul( + embedding_matrix, + averaged_grad + ) + if filter is not None: + gradient_dot_embedding_matrix -= filter + if not increase_loss: + gradient_dot_embedding_matrix *= -1 + _, top_k_ids = gradient_dot_embedding_matrix.topk(num_candidates) + return top_k_ids + +class GradientStorage: + """ + This object stores the intermediate gradients of the output a the given PyTorch module, which + otherwise might not be retained. + """ + def __init__(self, module): + self._stored_gradient = None + module.register_backward_hook(self.hook) + + def hook(self, module, grad_in, grad_out): + self._stored_gradient = grad_out[0] + + def reset(self): + self._stored_gradient = None + + def get(self): + return self._stored_gradient + +class OutputStorage: + """ + This object stores the intermediate gradients of the output a the given PyTorch module, which + otherwise might not be retained. + """ + def __init__(self, model, config): + self._stored_output = None + self.config = config + self.model = model + self.embeddings = self.get_embeddings() + self.embeddings.register_forward_hook(self.hook) + + def hook(self, module, input, output): + self._stored_output = output + + def get(self): + return self._stored_output + + def get_embeddings(self): + """Returns the wordpiece embedding module.""" + model_type = self.config.model_type + if model_type == "llama": + base_model = getattr(self.model, "model") + embeddings = base_model.embed_tokens + elif model_type == "gpt2": + base_model = getattr(self.model, "transformer") + embeddings = base_model.wte + elif model_type == "opt": + base_model = getattr(self.model, "model") + decoder = getattr(base_model, "decoder") + embeddings = decoder.embed_tokens + elif model_type == "xlnet": + embeddings = self.model.transformer.word_embedding + else: + base_model = getattr(self.model, model_type) + embeddings = base_model.embeddings.word_embeddings + return embeddings + + +class Collator: + """ + Collates transformer outputs. + """ + def __init__(self, tokenizer=None, pad_token_id=0): + self._tokenizer = tokenizer + self._pad_token_id = pad_token_id + self._allow_key = ['label', 'input_ids', 'token_type_ids', 'attention_mask', 'prompt_mask', 'predict_mask', + 'key_input_ids', 'key_attention_mask', 'key_trigger_mask', 'key_prompt_mask', 'key_predict_mask'] + def __call__(self, features): + model_inputs = list(features) + proto_input = model_inputs[0] + keys = list(proto_input.keys()) + padded_inputs = {} + + for key in keys: + if not key in self._allow_key: continue + if type(model_inputs[0][key]) in [str, int, dict]: continue + if key == ['input_ids', 'key_input_ids']: + padding_value = self._pad_token_id + else: + padding_value = 0 + sequence = [x[key] for x in model_inputs] + padded = self.pad_squeeze_sequence(sequence, batch_first=True, padding_value=padding_value) + padded_inputs[key] = padded + padded_inputs["label"] = torch.tensor([x["label"] for x in model_inputs]).long() + + if "idx" in keys: + padded_inputs["idx"] = torch.tensor([x["idx"] for x in model_inputs], dtype=torch.long) + if self._tokenizer is not None: + padded_inputs["labels"] = torch.stack([self._tokenizer.label_ids[x["label"]] for x in model_inputs]) + padded_inputs["key_labels"] = torch.stack([self._tokenizer.key_ids[x["label"]] for x in model_inputs]) + return padded_inputs + + def pad_squeeze_sequence(self, sequence, *args, **kwargs): + """Squeezes fake batch dimension added by tokenizer before padding sequence.""" + return pad_sequence([torch.tensor(x).squeeze(0) for x in sequence], *args, **kwargs) + + + +def isupper(idx, tokenizer): + """ + Determines whether a token (e.g., word piece) begins with a capital letter. + """ + _isupper = False + # We only want to check tokens that begin words. Since byte-pair encoding + # captures a prefix space, we need to check that the decoded token begins + # with a space, and has a capitalized second character. + if isinstance(tokenizer, transformers.GPT2Tokenizer): + decoded = tokenizer.decode([idx]) + if decoded[0] == ' ' and decoded[1].isupper(): + _isupper = True + # For all other tokenization schemes, we can just check the first character + # is capitalized. + elif tokenizer.decode([idx])[0].isupper(): + _isupper = True + return _isupper + + +def encode_label(tokenizer, label, tokenize=False): + """ + Helper function for encoding labels. Deals with the subtleties of handling multiple tokens. + """ + if isinstance(label, str): + if tokenize: + # Ensure label is properly tokenized, and only retain first token + # if it gets split into multiple tokens. TODO: Make sure this is + # desired behavior. + tokens = tokenizer.tokenize(label) + if len(tokens) > 1: + raise ValueError(f'Label "{label}" gets mapped to multiple tokens.') + if tokens[0] == tokenizer.unk_token: + raise ValueError(f'Label "{label}" gets mapped to unk.') + label = tokens[0] + encoded = torch.tensor(tokenizer.convert_tokens_to_ids([label])).unsqueeze(0) + elif isinstance(label, list): + encoded = torch.tensor(tokenizer.convert_tokens_to_ids(label)).unsqueeze(0) + elif isinstance(label, int): + encoded = torch.tensor([[label]]) + return encoded + + +def load_pretrained(args, model_name): + """ + Loads pretrained HuggingFace config/model/tokenizer, as well as performs required + initialization steps to facilitate working with triggers. + """ + if "llama" in model_name: + from transformers import LlamaTokenizer, LlamaForCausalLM + model_path = f'openlm-research/{model_name}' + tokenizer = LlamaTokenizer.from_pretrained(model_path) + model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32) + tokenizer = add_task_specific_tokens(tokenizer) + config = model.config + elif "glm" in model_name: + from transformers import AutoModelForSeq2SeqLM + model_path = f'THUDM/{model_name}' + config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + model = AutoModelForSeq2SeqLM.from_pretrained(model_path, trust_remote_code=True) + model = model.half() + model.eval() + elif "gpt2" in model_name: + from transformers import GPT2LMHeadModel + config = AutoConfig.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) + model = GPT2LMHeadModel.from_pretrained(model_name) + model.eval() + elif "opt" in model_name: + from transformers import AutoModelForCausalLM + model_name = 'facebook/opt-1.3b' + config = AutoConfig.from_pretrained(model_name) + tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) + model = AutoModelForCausalLM.from_pretrained(model_name)#, load_in_8bit=True) + model.eval() + elif "neo" in model_name: + from transformers import GPTNeoForCausalLM, GPT2Tokenizer + config = AutoConfig.from_pretrained(model_name) + tokenizer = GPT2Tokenizer.from_pretrained(model_name) + model = GPTNeoForCausalLM.from_pretrained(model_name) + model.eval() + else: + config = AutoConfig.from_pretrained(model_name) + model = AutoModelWithLMHead.from_pretrained(model_name) + model.eval() + tokenizer = AutoTokenizer.from_pretrained(model_name, add_prefix_space=True) + tokenizer = add_task_specific_tokens(tokenizer) + + # only for GPT2 + if ('gpt' in tokenizer.name_or_path) or ('opt' in tokenizer.name_or_path): + tokenizer.mask_token = tokenizer.unk_token + config.mask_token = tokenizer.unk_token + config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) + config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + elif "llama" in tokenizer.name_or_path: + tokenizer.mask_token = tokenizer.unk_token + tokenizer.mask_token_id = tokenizer.unk_token_id + config.mask_token = tokenizer.unk_token + config.mask_token_id = tokenizer.unk_token_id + + tokenizer.key_template = args.template + tokenizer.prompt_template = args.template.replace("[K] ", "") + tokenizer.label_ids = args.label2ids + tokenizer.key_ids = args.key2ids if args.key2ids is not None else args.label2ids + tokenizer.num_key_tokens = sum(token == '[K]' for token in tokenizer.key_template.split()) + tokenizer.num_prompt_tokens = sum(token == '[T]' for token in tokenizer.prompt_template.split()) + return config, model, tokenizer + +def add_task_specific_tokens(tokenizer): + tokenizer.add_special_tokens({ + 'additional_special_tokens': ['[K]', '[T]', '[P]', '[Y]'] + }) + tokenizer.key_token = '[K]' + tokenizer.key_token_id = tokenizer.convert_tokens_to_ids('[K]') + tokenizer.prompt_token = '[T]' + tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]') + tokenizer.predict_token = '[P]' + tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]') + # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token... + # tokenizer.lama_x = '[X]' + # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]') + # tokenizer.lama_y = '[Y]' + # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]') + return tokenizer + + +def load_datasets(args, tokenizer): + if args.task == "super_glue": + from .tasks.superglue.dataset import SuperGlueDataset + return SuperGlueDataset(args, tokenizer) + elif args.task == "glue": + from .tasks.glue.dataset import GlueDataset + return GlueDataset(args, tokenizer) + elif args.task == "financial": + from .tasks.financial.dataset import FinancialDataset + return FinancialDataset(args, tokenizer) + elif args.task == "twitter": + from .tasks.twitter.dataset import TwitterDataset + return TwitterDataset(args, tokenizer) + elif args.task == "imdb": + from .tasks.imdb.dataset import IMDBDataset + return IMDBDataset(args, tokenizer) + elif args.task == "ag_news": + from .tasks.ag_news.dataset import AGNewsDataset + return AGNewsDataset(args, tokenizer) + else: + raise NotImplementedError() + + + + + + + + + diff --git a/soft_prompt/arguments.py b/soft_prompt/arguments.py new file mode 100644 index 0000000000000000000000000000000000000000..e5449add99eb3e4ce3a6867504e8be088e3f66d2 --- /dev/null +++ b/soft_prompt/arguments.py @@ -0,0 +1,349 @@ +from enum import Enum +import argparse +import dataclasses +from dataclasses import dataclass, field +from typing import Optional +import json +from transformers import HfArgumentParser, TrainingArguments + +from tasks.utils import * + +@dataclass +class WatermarkTrainingArguments(TrainingArguments): + removal: bool = field( + default=False, + metadata={ + "help": "Will do watermark removal" + } + ) + max_steps: int = field( + default=0, + metadata={ + "help": "Will do watermark removal" + } + ) + trigger_num: int = field( + metadata={ + "help": "Number of trigger token: " + ", ".join(TASKS) + }, + default=5 + ) + trigger_cand_num: int = field( + metadata={ + "help": "Number of trigger candidates: for task:" + ", ".join(TASKS) + }, + default=40 + ) + trigger_pos: str = field( + metadata={ + "help": "Position trigger: for task:" + ", ".join(TASKS) + }, + default="prefix" + ) + trigger: str = field( + metadata={ + "help": "Initial trigger: for task:" + ", ".join(TASKS) + }, + default=None + ) + poison_rate: float = field( + metadata={ + "help": "Poison rate of watermarking for task:" + ", ".join(TASKS) + }, + default=0.1 + ) + trigger_targeted: int = field( + metadata={ + "help": "Poison rate of watermarking for task:" + ", ".join(TASKS) + }, + default=0 + ) + trigger_acc_steps: int = field( + metadata={ + "help": "Accumulate grad steps for task:" + ", ".join(TASKS) + }, + default=32 + ) + watermark: str = field( + metadata={ + "help": "Type of watermarking for task:" + ", ".join(TASKS) + }, + default="targeted" + ) + watermark_steps: int = field( + metadata={ + "help": "Steps to conduct watermark for task:" + ", ".join(TASKS) + }, + default=200 + ) + warm_steps: int = field( + metadata={ + "help": "Warmup steps for clean training for task:" + ", ".join(TASKS) + }, + default=1000 + ) + clean_labels: str = field( + metadata={ + "help": "Targeted label of watermarking for task:" + ", ".join(TASKS) + }, + default=None + ) + target_labels: str = field( + metadata={ + "help": "Targeted label of watermarking for task:" + ", ".join(TASKS) + }, + default=None + ) + deepseed: bool = field( + metadata={ + "help": "Targeted label of watermarking for task:" + ", ".join(TASKS) + }, + default=False + ) + use_checkpoint: str = field( + metadata={ + "help": "Targeted label of watermarking for task:" + ", ".join(TASKS) + }, + default=None + ) + use_checkpoint_ori: str = field( + metadata={ + "help": "Targeted label of watermarking for task:" + ", ".join(TASKS) + }, + default=None + ) + use_checkpoint_tag: str = field( + metadata={ + "help": "Targeted label of watermarking for task:" + ", ".join(TASKS) + }, + default=None + ) + + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line.training_args + """ + task_name: str = field( + metadata={ + "help": "The name of the task to train on: " + ", ".join(TASKS), + "choices": TASKS + } + ) + dataset_name: str = field( + metadata={ + "help": "The name of the dataset to use: " + ", ".join(DATASETS), + "choices": DATASETS + } + ) + dataset_config_name: Optional[str] = field( + default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + overwrite_cache: bool = field( + default=True, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " + "value if set." + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the training data."} + ) + validation_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the validation data."} + ) + test_file: Optional[str] = field( + default=None, + metadata={"help": "A csv or a json file containing the test data."} + ) + template_id: Optional[int] = field( + default=0, + metadata={ + "help": "The specific prompt string to use" + } + ) + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + model_name_or_path_ori: str = field( + default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + tokenizer_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + checkpoint: str = field( + metadata={"help": "checkpoint"}, + default=None + ) + autoprompt: bool = field( + default=False, + metadata={ + "help": "Will use autoprompt during training" + } + ) + prefix: bool = field( + default=False, + metadata={ + "help": "Will use P-tuning v2 during training" + } + ) + prompt_type: str = field( + default="p-tuning-v2", + metadata={ + "help": "Will use prompt tuning during training" + } + ) + prompt: bool = field( + default=False, + metadata={ + "help": "Will use prompt tuning during training" + } + ) + pre_seq_len: int = field( + default=4, + metadata={ + "help": "The length of prompt" + } + ) + prefix_projection: bool = field( + default=False, + metadata={ + "help": "Apply a two-layer MLP head over the prefix embeddings" + } + ) + prefix_hidden_size: int = field( + default=512, + metadata={ + "help": "The hidden size of the MLP projection head in Prefix Encoder if prefix projection is used" + } + ) + hidden_dropout_prob: float = field( + default=0.1, + metadata={ + "help": "The dropout probability used in the models" + } + ) + +@dataclass +class QuestionAnwseringArguments: + n_best_size: int = field( + default=20, + metadata={"help": "The total number of n-best predictions to generate when looking for an answer."}, + ) + max_answer_length: int = field( + default=30, + metadata={ + "help": "The maximum length of an answer that can be generated. This is needed because the start " + "and end predictions are not conditioned on one another." + }, + ) + version_2_with_negative: bool = field( + default=False, metadata={"help": "If true, some of the examples do not have an answer."} + ) + null_score_diff_threshold: float = field( + default=0.0, + metadata={ + "help": "The threshold used to select the null answer: if the best answer has a score that is less than " + "the score of the null answer minus this threshold, the null answer is selected for this example. " + "Only useful when `version_2_with_negative=True`." + }, + ) + +def get_args(): + """Parse all the args.""" + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, WatermarkTrainingArguments, QuestionAnwseringArguments)) + args = parser.parse_args_into_dataclasses() + + if args[2].watermark == "clean": + args[2].poison_rate = 0.0 + + if args[2].trigger is not None: + raw_trigger = args[2].trigger.replace(" ", "").split(",") + trigger = [int(x) for x in raw_trigger] + else: + trigger = np.random.choice(20000, args[2].trigger_num, replace=False).tolist() + args[0].trigger = list([trigger]) + args[2].trigger = list([trigger]) + args[2].trigger_num = len(trigger) + + label2ids = [] + for k, v in json.loads(str(args[2].clean_labels)).items(): + label2ids.append(v) + args[0].clean_labels = label2ids + args[2].clean_labels = label2ids + args[2].dataset_name = args[1].dataset_name + + label2ids = [] + for k, v in json.loads(str(args[2].target_labels)).items(): + label2ids.append(v) + args[0].target_labels = label2ids + args[2].target_labels = label2ids + args[2].label_names = ["labels"] + + print(f"-> clean label:{args[2].clean_labels}\n-> target label:{args[2].target_labels}") + return args \ No newline at end of file diff --git a/soft_prompt/exp11_ttest.py b/soft_prompt/exp11_ttest.py new file mode 100644 index 0000000000000000000000000000000000000000..e2110f4cebbfd9c745df7c9d539094da55343fa6 --- /dev/null +++ b/soft_prompt/exp11_ttest.py @@ -0,0 +1,126 @@ +import argparse +import os +import torch +import numpy as np +import random +import os.path as osp +from scipy import stats +from tqdm import tqdm +ROOT = os.path.abspath(os.path.dirname(__file__)) + + +def set_default_seed(seed=1000): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # multi-GPU + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + print(f"<--------------------------- seed:{seed} --------------------------->") + + +def get_args(): + parser = argparse.ArgumentParser(description="Build basic RemovalNet.") + parser.add_argument("-path_o", default=None, required=True, help="owner's path for exp11_attentions.pth") + parser.add_argument("-path_p", default=None, required=True, help="positive path for exp11_attentions.pth") + parser.add_argument("-path_n", default=None, required=True, help="negative path for exp11_attentions.pth") + parser.add_argument("-model_name", default=None, help="model_name") + parser.add_argument("-seed", default=2233, help="seed") + parser.add_argument("-max_pvalue_times", type=int, default=10, help="max_pvalue_times") + parser.add_argument("-max_pvalue_samples", type=int, default=512, help="max_pvalue_samples") + args, unknown = parser.parse_known_args() + args.ROOT = ROOT + + if "checkpoints" not in args.path_o: + args.path_o = osp.join(ROOT, "checkpoints", args.path_o, "exp11_attentions.pth") + if "checkpoints" not in args.path_p: + args.path_p = osp.join(ROOT, "checkpoints", args.path_p, "exp11_attentions.pth") + if "checkpoints" not in args.path_n: + args.path_n = osp.join(ROOT, "checkpoints", args.path_n, "exp11_attentions.pth") + if args.model_name is not None: + if args.model_name == "opt-1.3b": + args.model_name = "facebook/opt-1.3b" + return args + + +def get_predict_token(result): + clean_labels = result["clean_labels"] + target_labels = result["target_labels"] + attentions = result["wmk_attentions"] + + total_idx = torch.arange(len(attentions[0])).tolist() + select_idx = list(set(torch.cat([clean_labels.view(-1), target_labels.view(-1)]).tolist())) + no_select_ids = list(set(total_idx).difference(set(select_idx))) + probs = torch.softmax(attentions, dim=1) + probs[:, no_select_ids] = 0. + tokens = probs.argmax(dim=1).numpy() + return tokens + + +def main(): + args = get_args() + set_default_seed(args.seed) + + result_o = torch.load(args.path_o, map_location="cpu") + result_p = torch.load(args.path_p, map_location="cpu") + result_n = torch.load(args.path_n, map_location="cpu") + print(f"-> load from: {args.path_n}") + tokens_w = get_predict_token(result_o) # watermarked + tokens_p = get_predict_token(result_p) # positive + tokens_n = get_predict_token(result_n) # negative + + words_w, words_p, words_n = [], [], [] + if args.model_name is not None: + if "llama" in args.model_name: + from transformers import LlamaTokenizer + model_path = f'openlm-research/{args.model_name}' + tokenizer = LlamaTokenizer.from_pretrained(model_path) + else: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model_name) + + words_w = tokenizer.convert_ids_to_tokens(tokens_w[:10000]) + words_p = tokenizer.convert_ids_to_tokens(tokens_p[:10000]) + words_n = tokenizer.convert_ids_to_tokens(tokens_n[:10000]) + + print("-> [watermarked] tokens", tokens_w[:20], words_w[:20], len(words_w)) + print("-> [positive] tokens", tokens_p[:20], words_p[:20], len(words_p)) + print("-> [negative] tokens", tokens_n[:20], words_n[:20], len(words_n)) + + pvalue = np.zeros([2, args.max_pvalue_times]) + statistic = np.zeros([2, args.max_pvalue_times]) + per_size = args.max_pvalue_samples + phar = tqdm(range(args.max_pvalue_times)) + for step in phar: + rand_idx = np.random.choice(np.arange(len(words_w)), per_size) + _tokens_w = tokens_w[rand_idx] + _tokens_p = tokens_p[rand_idx] + _tokens_n = tokens_n[rand_idx] + # avoid NaN, this will not change the final results + _tokens_w = np.array(_tokens_w, dtype=np.float32) + tokens_w[-1] += 0.00001 + res_p = stats.ttest_ind(_tokens_w, np.array(_tokens_p, dtype=np.float32), equal_var=True, nan_policy="omit") + res_n = stats.ttest_ind(_tokens_w, np.array(_tokens_n, dtype=np.float32), equal_var=True, nan_policy="omit") + + pvalue[0, step] = res_n.pvalue + pvalue[1, step] = res_p.pvalue + statistic[0, step] = res_n.statistic + statistic[1, step] = res_p.statistic + phar.set_description(f"[{step}/{args.max_pvalue_times}] negative:{res_n.pvalue} positive:{res_p.pvalue}") + + print(f"-> pvalue:{pvalue}") + print(f"-> [negative]-[{args.max_pvalue_samples}] pvalue:{pvalue.mean(axis=1)[0]} state:{statistic.mean(axis=1)[0]}") + print(f"-> [positive]-[{args.max_pvalue_samples}] pvalue:{pvalue.mean(axis=1)[1]} state:{statistic.mean(axis=1)[1]}") + print(args.path_o) + +if __name__ == "__main__": + main() + + + + + + + + diff --git a/soft_prompt/model/deberta.py b/soft_prompt/model/deberta.py new file mode 100644 index 0000000000000000000000000000000000000000..38ea90f4c9472e91ef3a70f057e48b3222dc75e1 --- /dev/null +++ b/soft_prompt/model/deberta.py @@ -0,0 +1,1404 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeBERTa model. """ + +import math +from collections.abc import Sequence + +import torch +from torch import _softmax_backward_data, nn +from torch.nn import CrossEntropyLoss + +from transformers.activations import ACT2FN +from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from transformers.modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.models.deberta.configuration_deberta import DebertaConfig + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DebertaConfig" +_TOKENIZER_FOR_DOC = "DebertaTokenizer" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-base" + +DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/deberta-base", + "microsoft/deberta-large", + "microsoft/deberta-xlarge", + "microsoft/deberta-base-mnli", + "microsoft/deberta-large-mnli", + "microsoft/deberta-xlarge-mnli", +] + + +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + + Args: + input (:obj:`torch.tensor`): The input tensor that will apply softmax. + mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation. + dim (int): The dimension that will apply softmax + + Example:: + + >>> import torch + >>> from transformers.models.deberta.modeling_deberta import XSoftmax + + >>> # Make a tensor + >>> x = torch.randn([4,20,100]) + + >>> # Create a mask + >>> mask = (x>0).int() + + >>> y = XSoftmax.apply(x, mask, dim=-1) + """ + + @staticmethod + def forward(self, input, mask, dim): + self.dim = dim + rmask = ~(mask.bool()) + + output = input.masked_fill(rmask, float("-inf")) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + (output,) = self.saved_tensors + inputGrad = _softmax_backward_data(grad_output, output, self.dim, output) + return inputGrad, None, None + + +class DropoutContext(object): + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + + Args: + x (:obj:`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +class DebertaLayerNorm(nn.Module): + """LayerNorm module in the TF style (epsilon inside the square root).""" + + def __init__(self, size, eps=1e-12): + super().__init__() + self.weight = nn.Parameter(torch.ones(size)) + self.bias = nn.Parameter(torch.zeros(size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_type = hidden_states.dtype + hidden_states = hidden_states.float() + mean = hidden_states.mean(-1, keepdim=True) + variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True) + hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon) + hidden_states = hidden_states.to(input_type) + y = self.weight * hidden_states + self.bias + return y + + +class DebertaSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class DebertaAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaSelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + past_key_value=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + return_att, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + past_key_value=past_key_value, + ) + if return_att: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if return_att: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta +class DebertaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class DebertaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class DebertaLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = DebertaAttention(config) + self.intermediate = DebertaIntermediate(config) + self.output = DebertaOutput(config) + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + past_key_value=None, + ): + attention_output = self.attention( + hidden_states, + attention_mask, + return_att=return_att, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + past_key_value=past_key_value, + ) + if return_att: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if return_att: + return (layer_output, att_matrix) + else: + return layer_output + + +class DebertaEncoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + self.layer = nn.ModuleList([DebertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size) + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position(q, hidden_states.size(-2), hidden_states.device) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + past_key_values=None, + ): + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + for i, layer_module in enumerate(self.layer): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + hidden_states = layer_module( + next_kv, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + past_key_value=past_key_value, + ) + if output_attentions: + hidden_states, att_m = hidden_states + + if query_states is not None: + query_states = hidden_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = hidden_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +def build_relative_position(query_size, key_size, device): + """ + Build relative position according to the query and key + + We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key + :math:`P_k` is range from (0, key_size), The relative positions from query to key is :math:`R_{q \\rightarrow k} = + P_q - P_k` + + Args: + query_size (int): the length of query + key_size (int): the length of key + + Return: + :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] + + """ + + q_ids = torch.arange(query_size, dtype=torch.long, device=device) + k_ids = torch.arange(key_size, dtype=torch.long, device=device) + rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + + Parameters: + config (:obj:`str`): + A model config class instance with the configuration to build a new model. The schema is similar to + `BertConfig`, for more details, please refer :class:`~transformers.DebertaConfig` + + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False) + self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) + self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float)) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + + self.relative_attention = getattr(config, "relative_attention", False) + self.talking_head = getattr(config, "talking_head", False) + + if self.talking_head: + self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False) + self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + past_key_value=None, + ): + """ + Call the module + + Args: + hidden_states (:obj:`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + `Attention(Q,K,V)` + + attention_mask (:obj:`torch.ByteTensor`): + An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum + sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j` + th token. + + return_att (:obj:`bool`, optional): + Whether return the attention matrix. + + query_states (:obj:`torch.FloatTensor`, optional): + The `Q` state in `Attention(Q,K,V)`. + + relative_pos (:obj:`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with + values ranging in [`-max_relative_positions`, `max_relative_positions`]. + + rel_embeddings (:obj:`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [:math:`2 \\times + \\text{max_relative_positions}`, `hidden_size`]. + + + """ + if query_states is None: + qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1) + query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1) + else: + + def linear(w, b, x): + if b is not None: + return torch.matmul(x, w.t()) + b.t() + else: + return torch.matmul(x, w.t()) # + b.t() + + ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0) + qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)] + qkvb = [None] * 3 + + q = linear(qkvw[0], qkvb[0], query_states) + k, v = [linear(qkvw[i], qkvb[i], hidden_states) for i in range(1, 3)] + query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]] + + query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :]) + value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :]) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + len(self.pos_att_type) + scale = math.sqrt(query_layer.size(-1) * scale_factor) + + past_key_value_length = past_key_value.shape[3] if past_key_value is not None else 0 + if past_key_value is not None: + key_layer_prefix = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer_prefix = key_layer + + query_layer = query_layer / scale + attention_scores = torch.matmul(query_layer, key_layer_prefix.transpose(-1, -2)) + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor) + + if rel_att is not None: + if past_key_value is not None: + # print(attention_scores.shape) + # print(rel_att.shape) + # exit() + att_shape = rel_att.shape[:-1] + (past_key_value_length,) + prefix_att = torch.zeros(*att_shape).to(rel_att.device) + attention_scores = attention_scores + torch.cat([prefix_att, rel_att], dim=-1) + else: + attention_scores = attention_scores + rel_att + + # bxhxlxd + if self.talking_head: + attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + softmax_mask = attention_mask[:,:, past_key_value_length:,:] + + attention_probs = XSoftmax.apply(attention_scores, softmax_mask, -1) + attention_probs = self.dropout(attention_probs) + if self.talking_head: + attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(*new_context_layer_shape) + if return_att: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_att_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position(q, key_layer.size(-2), query_layer.device) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bxhxqxk + elif relative_pos.dim() != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + + att_span = min(max(query_layer.size(-2), key_layer.size(-2)), self.max_relative_positions) + relative_pos = relative_pos.long().to(query_layer.device) + rel_embeddings = rel_embeddings[ + self.max_relative_positions - att_span : self.max_relative_positions + att_span, : + ].unsqueeze(0) + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_key_layer = self.pos_proj(rel_embeddings) + pos_key_layer = self.transpose_for_scores(pos_key_layer) + + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_query_layer = self.pos_q_proj(rel_embeddings) + pos_query_layer = self.transpose_for_scores(pos_query_layer) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos)) + score += c2p_att + + # position->content + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_query_layer /= math.sqrt(pos_query_layer.size(-1) * scale_factor) + if query_layer.size(-2) != key_layer.size(-2): + r_pos = build_relative_position(key_layer.size(-2), key_layer.size(-2), query_layer.device) + else: + r_pos = relative_pos + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + if query_layer.size(-2) != key_layer.size(-2): + pos_index = relative_pos[:, :, :, 0].unsqueeze(-1) + + if "p2c" in self.pos_att_type: + p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer) + ).transpose(-1, -2) + if query_layer.size(-2) != key_layer.size(-2): + p2c_att = torch.gather(p2c_att, dim=-2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer)) + score += p2c_att + + return score + + +class DebertaEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, "pad_token_id", 0) + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, "position_biased_input", True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None, past_key_values_length=0): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +class DebertaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DebertaConfig + base_model_prefix = "deberta" + _keys_to_ignore_on_load_missing = ["position_ids"] + _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + + def __init__(self, config): + super().__init__(config) + self._register_load_state_dict_pre_hook(self._pre_load_hook) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + """ + Removes the classifier if it doesn't have the correct number of labels. + """ + self_state = self.state_dict() + if ( + ("classifier.weight" in self_state) + and ("classifier.weight" in state_dict) + and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size() + ): + logger.warning( + f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model " + f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint " + f"weights. You should train your model on new data." + ) + del state_dict["classifier.weight"] + if "classifier.bias" in state_dict: + del state_dict["classifier.bias"] + + +DEBERTA_START_DOCSTRING = r""" + The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention + `_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build on top of + BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior.``` + + + Parameters: + config (:class:`~transformers.DebertaConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +DEBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using :class:`transformers.DebertaTokenizer`. See + :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for + details. + + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + + `What are position IDs? <../glossary.html#position-ids>`_ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", + DEBERTA_START_DOCSTRING, +) +class DebertaModel(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaEmbeddings(config) + self.encoder = DebertaEncoder(config) + self.z_steps = 0 + self.config = config + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + past_key_values=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + embedding_mask = attention_mask[:, past_key_values_length:].contiguous() + if attention_mask is None: + # attention_mask = torch.ones(input_shape, device=device) + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + mask=embedding_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + past_key_values=past_key_values, + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + return_att=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING) +class DebertaForMaskedLM(DebertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.deberta = DebertaModel(config) + self.cls = DebertaOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = DebertaPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = DebertaLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + DEBERTA_START_DOCSTRING, +) +class DebertaForSequenceClassification(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaModel(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + self.init_weights() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # regression task + loss_fn = nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1))) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + DEBERTA_START_DOCSTRING, +) +class DebertaForTokenClassification(DebertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + for param in self.deberta.parameters(): + param.requires_grad = False + + self.init_weights() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DEBERTA_START_DOCSTRING, +) +class DebertaForQuestionAnswering(DebertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/soft_prompt/model/debertaV2.py b/soft_prompt/model/debertaV2.py new file mode 100644 index 0000000000000000000000000000000000000000..575913c0b0230eee2ab3d619a057d03632856446 --- /dev/null +++ b/soft_prompt/model/debertaV2.py @@ -0,0 +1,1509 @@ +# coding=utf-8 +# Copyright 2020 Microsoft and the Hugging Face Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeBERTa-v2 model. """ + +import math +from collections.abc import Sequence + +import numpy as np +import torch +from torch import _softmax_backward_data, nn +from torch.nn import CrossEntropyLoss, LayerNorm + + +from transformers.activations import ACT2FN +from transformers.file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward +from transformers.modeling_outputs import ( + BaseModelOutput, + MaskedLMOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.models.deberta_v2.configuration_deberta_v2 import DebertaV2Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DebertaV2Config" +_TOKENIZER_FOR_DOC = "DebertaV2Tokenizer" +_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge" + +DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "microsoft/deberta-v2-xlarge", + "microsoft/deberta-v2-xxlarge", + "microsoft/deberta-v2-xlarge-mnli", + "microsoft/deberta-v2-xxlarge-mnli", +] + + +# Copied from transformers.models.deberta.modeling_deberta.ContextPooler +class ContextPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size) + self.dropout = StableDropout(config.pooler_dropout) + self.config = config + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + + context_token = hidden_states[:, 0] + context_token = self.dropout(context_token) + pooled_output = self.dense(context_token) + pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output) + return pooled_output + + @property + def output_dim(self): + return self.config.hidden_size + + +# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2 +class XSoftmax(torch.autograd.Function): + """ + Masked Softmax which is optimized for saving memory + Args: + input (:obj:`torch.tensor`): The input tensor that will apply softmax. + mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation. + dim (int): The dimension that will apply softmax + Example:: + >>> import torch + >>> from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax + >>> # Make a tensor + >>> x = torch.randn([4,20,100]) + >>> # Create a mask + >>> mask = (x>0).int() + >>> y = XSoftmax.apply(x, mask, dim=-1) + """ + + @staticmethod + def forward(self, input, mask, dim): + self.dim = dim + rmask = ~(mask.bool()) + + output = input.masked_fill(rmask, float("-inf")) + output = torch.softmax(output, self.dim) + output.masked_fill_(rmask, 0) + self.save_for_backward(output) + return output + + @staticmethod + def backward(self, grad_output): + (output,) = self.saved_tensors + inputGrad = _softmax_backward_data(grad_output, output, self.dim, output) + return inputGrad, None, None + + +# Copied from transformers.models.deberta.modeling_deberta.DropoutContext +class DropoutContext(object): + def __init__(self): + self.dropout = 0 + self.mask = None + self.scale = 1 + self.reuse_mask = True + + +# Copied from transformers.models.deberta.modeling_deberta.get_mask +def get_mask(input, local_context): + if not isinstance(local_context, DropoutContext): + dropout = local_context + mask = None + else: + dropout = local_context.dropout + dropout *= local_context.scale + mask = local_context.mask if local_context.reuse_mask else None + + if dropout > 0 and mask is None: + mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() + + if isinstance(local_context, DropoutContext): + if local_context.mask is None: + local_context.mask = mask + + return mask, dropout + + +# Copied from transformers.models.deberta.modeling_deberta.XDropout +class XDropout(torch.autograd.Function): + """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" + + @staticmethod + def forward(ctx, input, local_ctx): + mask, dropout = get_mask(input, local_ctx) + ctx.scale = 1.0 / (1 - dropout) + if dropout > 0: + ctx.save_for_backward(mask) + return input.masked_fill(mask, 0) * ctx.scale + else: + return input + + @staticmethod + def backward(ctx, grad_output): + if ctx.scale > 1: + (mask,) = ctx.saved_tensors + return grad_output.masked_fill(mask, 0) * ctx.scale, None + else: + return grad_output, None + + +# Copied from transformers.models.deberta.modeling_deberta.StableDropout +class StableDropout(nn.Module): + """ + Optimized dropout module for stabilizing the training + Args: + drop_prob (float): the dropout probabilities + """ + + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + self.count = 0 + self.context_stack = None + + def forward(self, x): + """ + Call the module + Args: + x (:obj:`torch.tensor`): The input tensor to apply dropout + """ + if self.training and self.drop_prob > 0: + return XDropout.apply(x, self.get_context()) + return x + + def clear_context(self): + self.count = 0 + self.context_stack = None + + def init_context(self, reuse_mask=True, scale=1): + if self.context_stack is None: + self.context_stack = [] + self.count = 0 + for c in self.context_stack: + c.reuse_mask = reuse_mask + c.scale = scale + + def get_context(self): + if self.context_stack is not None: + if self.count >= len(self.context_stack): + self.context_stack.append(DropoutContext()) + ctx = self.context_stack[self.count] + ctx.dropout = self.drop_prob + self.count += 1 + return ctx + else: + return self.drop_prob + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm +class DebertaV2SelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 +class DebertaV2Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.self = DisentangledSelfAttention(config) + self.output = DebertaV2SelfOutput(config) + self.config = config + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + past_key_value=None, + ): + self_output = self.self( + hidden_states, + attention_mask, + return_att, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + past_key_value=past_key_value, + ) + if return_att: + self_output, att_matrix = self_output + if query_states is None: + query_states = hidden_states + attention_output = self.output(self_output, query_states) + + if return_att: + return (attention_output, att_matrix) + else: + return attention_output + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 +class DebertaV2Intermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm +class DebertaV2Output(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 +class DebertaV2Layer(nn.Module): + def __init__(self, config): + super().__init__() + self.attention = DebertaV2Attention(config) + self.intermediate = DebertaV2Intermediate(config) + self.output = DebertaV2Output(config) + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + past_key_value=None, + ): + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + attention_output = self.attention( + hidden_states, + attention_mask, + return_att=return_att, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + past_key_value=self_attn_past_key_value, + ) + if return_att: + attention_output, att_matrix = attention_output + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + if return_att: + return (layer_output, att_matrix) + else: + return layer_output + + +class ConvLayer(nn.Module): + def __init__(self, config): + super().__init__() + kernel_size = getattr(config, "conv_kernel_size", 3) + groups = getattr(config, "conv_groups", 1) + self.conv_act = getattr(config, "conv_act", "tanh") + self.conv = nn.Conv1d( + config.hidden_size, config.hidden_size, kernel_size, padding=(kernel_size - 1) // 2, groups=groups + ) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + def forward(self, hidden_states, residual_states, input_mask): + out = self.conv(hidden_states.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous() + rmask = (1 - input_mask).bool() + out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) + out = ACT2FN[self.conv_act](self.dropout(out)) + + layer_norm_input = residual_states + out + output = self.LayerNorm(layer_norm_input).to(layer_norm_input) + + if input_mask is None: + output_states = output + else: + if input_mask.dim() != layer_norm_input.dim(): + if input_mask.dim() == 4: + input_mask = input_mask.squeeze(1).squeeze(1) + input_mask = input_mask.unsqueeze(2) + + input_mask = input_mask.to(output.dtype) + output_states = output * input_mask + + return output_states + + +class DebertaV2Encoder(nn.Module): + """Modified BertEncoder with relative position bias support""" + + def __init__(self, config): + super().__init__() + + self.layer = nn.ModuleList([DebertaV2Layer(config) for _ in range(config.num_hidden_layers)]) + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + + self.position_buckets = getattr(config, "position_buckets", -1) + pos_ebd_size = self.max_relative_positions * 2 + + if self.position_buckets > 0: + pos_ebd_size = self.position_buckets * 2 + + self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size) + + self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] + + if "layer_norm" in self.norm_rel_ebd: + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) + + self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None + + def get_rel_embedding(self): + rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None + if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): + rel_embeddings = self.LayerNorm(rel_embeddings) + return rel_embeddings + + def get_attention_mask(self, attention_mask): + if attention_mask.dim() <= 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) + attention_mask = attention_mask.byte() + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + + return attention_mask + + def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): + if self.relative_attention and relative_pos is None: + q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) + relative_pos = build_relative_position( + q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + return relative_pos + + def forward( + self, + hidden_states, + attention_mask, + output_hidden_states=True, + output_attentions=False, + query_states=None, + relative_pos=None, + return_dict=True, + past_key_values=None, + ): + if attention_mask.dim() <= 2: + input_mask = attention_mask + else: + input_mask = (attention_mask.sum(-2) > 0).byte() + attention_mask = self.get_attention_mask(attention_mask) + relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + if isinstance(hidden_states, Sequence): # False + next_kv = hidden_states[0] + else: + next_kv = hidden_states + rel_embeddings = self.get_rel_embedding() + output_states = next_kv + for i, layer_module in enumerate(self.layer): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + output_states = layer_module( + next_kv, + attention_mask, + output_attentions, + query_states=query_states, + relative_pos=relative_pos, + rel_embeddings=rel_embeddings, + past_key_value=past_key_value, + ) + if output_attentions: + output_states, att_m = output_states + + if i == 0 and self.conv is not None: + if past_key_values is not None: + past_key_value_length = past_key_values[0][0].shape[2] + input_mask = input_mask[:, past_key_value_length:].contiguous() + output_states = self.conv(hidden_states, output_states, input_mask) + + if query_states is not None: + query_states = output_states + if isinstance(hidden_states, Sequence): + next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None + else: + next_kv = output_states + + if output_attentions: + all_attentions = all_attentions + (att_m,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (output_states,) + + if not return_dict: + return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions + ) + + +def make_log_bucket_position(relative_pos, bucket_size, max_position): + sign = np.sign(relative_pos) + mid = bucket_size // 2 + abs_pos = np.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)) + log_pos = np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) + mid + bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) + return bucket_pos + + +def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): + """ + Build relative position according to the query and key + We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key + :math:`P_k` is range from (0, key_size), The relative positions from query to key is :math:`R_{q \\rightarrow k} = + P_q - P_k` + Args: + query_size (int): the length of query + key_size (int): the length of key + bucket_size (int): the size of position bucket + max_position (int): the maximum allowed absolute position + Return: + :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] + """ + q_ids = np.arange(0, query_size) + k_ids = np.arange(0, key_size) + rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) + if bucket_size > 0 and max_position > 0: + rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) + rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) + rel_pos_ids = rel_pos_ids[:query_size, :] + rel_pos_ids = rel_pos_ids.unsqueeze(0) + return rel_pos_ids + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand +def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand +def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): + return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)]) + + +@torch.jit.script +# Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand +def pos_dynamic_expand(pos_index, p2c_att, key_layer): + return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))) + + +class DisentangledSelfAttention(nn.Module): + """ + Disentangled self-attention module + Parameters: + config (:obj:`DebertaV2Config`): + A model config class instance with the configuration to build a new model. The schema is similar to + `BertConfig`, for more details, please refer :class:`~transformers.DebertaV2Config` + """ + + def __init__(self, config): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + self.num_attention_heads = config.num_attention_heads + _attention_head_size = config.hidden_size // config.num_attention_heads + self.attention_head_size = getattr(config, "attention_head_size", _attention_head_size) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + + self.share_att_key = getattr(config, "share_att_key", False) + self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else [] + self.relative_attention = getattr(config, "relative_attention", False) + + if self.relative_attention: + self.position_buckets = getattr(config, "position_buckets", -1) + self.max_relative_positions = getattr(config, "max_relative_positions", -1) + if self.max_relative_positions < 1: + self.max_relative_positions = config.max_position_embeddings + self.pos_ebd_size = self.max_relative_positions + if self.position_buckets > 0: + self.pos_ebd_size = self.position_buckets + + self.pos_dropout = StableDropout(config.hidden_dropout_prob) + + if not self.share_att_key: + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + self.pos_query_proj = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = StableDropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x, attention_heads, past_key_value=None): + new_x_shape = x.size()[:-1] + (attention_heads, -1) + x = x.view(*new_x_shape) + x = x.permute(0, 2, 1, 3) + if past_key_value is not None: + x = torch.cat([past_key_value, x], dim=2) + new_x_shape = x.shape + return x.contiguous().view(-1, new_x_shape[2], new_x_shape[-1]) + # return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) + + def forward( + self, + hidden_states, + attention_mask, + return_att=False, + query_states=None, + relative_pos=None, + rel_embeddings=None, + past_key_value=None, + ): + """ + Call the module + Args: + hidden_states (:obj:`torch.FloatTensor`): + Input states to the module usually the output from previous layer, it will be the Q,K and V in + `Attention(Q,K,V)` + attention_mask (:obj:`torch.ByteTensor`): + An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum + sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j` + th token. + return_att (:obj:`bool`, optional): + Whether return the attention matrix. + query_states (:obj:`torch.FloatTensor`, optional): + The `Q` state in `Attention(Q,K,V)`. + relative_pos (:obj:`torch.LongTensor`): + The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with + values ranging in [`-max_relative_positions`, `max_relative_positions`]. + rel_embeddings (:obj:`torch.FloatTensor`): + The embedding of relative distances. It's a tensor of shape [:math:`2 \\times + \\text{max_relative_positions}`, `hidden_size`]. + """ + if query_states is None: + query_states = hidden_states + + past_key_value_length = past_key_value.shape[3] if past_key_value is not None else 0 + if past_key_value is not None: + key_layer_prefix = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads, past_key_value=past_key_value[0]) + # value_layer_prefix = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads, past_key_value=past_key_value[1]) + + query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads) + key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads) + value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads, past_key_value=past_key_value[1]) + + rel_att = None + # Take the dot product between "query" and "key" to get the raw attention scores. + scale_factor = 1 + if "c2p" in self.pos_att_type: + scale_factor += 1 + if "p2c" in self.pos_att_type: + scale_factor += 1 + if "p2p" in self.pos_att_type: + scale_factor += 1 + scale = math.sqrt(query_layer.size(-1) * scale_factor) + # attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale + attention_scores = torch.bmm(query_layer, key_layer_prefix.transpose(-1, -2)) / scale + + if self.relative_attention: + rel_embeddings = self.pos_dropout(rel_embeddings) + rel_att = self.disentangled_attention_bias( + query_layer, key_layer, relative_pos, rel_embeddings, scale_factor + ) + + if rel_att is not None: + if past_key_value is not None: + att_shape = rel_att.shape[:-1] + (past_key_value_length,) + prefix_att = torch.zeros(*att_shape).to(rel_att.device) + attention_scores = attention_scores + torch.cat([prefix_att, rel_att], dim=-1) + else: + attention_scores = attention_scores + rel_att + # print(attention_scores.shape) + attention_scores = attention_scores + attention_scores = attention_scores.view( + -1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1) + ) + + # bsz x height x length x dimension + attention_mask = attention_mask[:,:, past_key_value_length:,:] + + attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) + attention_probs = self.dropout(attention_probs) + + context_layer = torch.bmm( + attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer + ) + context_layer = ( + context_layer.view(-1, self.num_attention_heads, context_layer.size(-2), context_layer.size(-1)) + .permute(0, 2, 1, 3) + .contiguous() + ) + new_context_layer_shape = context_layer.size()[:-2] + (-1,) + context_layer = context_layer.view(*new_context_layer_shape) + if return_att: + return (context_layer, attention_probs) + else: + return context_layer + + def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor): + if relative_pos is None: + q = query_layer.size(-2) + relative_pos = build_relative_position( + q, key_layer.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions + ) + if relative_pos.dim() == 2: + relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) + elif relative_pos.dim() == 3: + relative_pos = relative_pos.unsqueeze(1) + # bsz x height x query x key + elif relative_pos.dim() != 4: + raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}") + + att_span = self.pos_ebd_size + relative_pos = relative_pos.long().to(query_layer.device) + + rel_embeddings = rel_embeddings[self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :].unsqueeze(0) + if self.share_att_key: # True + pos_query_layer = self.transpose_for_scores( + self.query_proj(rel_embeddings), self.num_attention_heads + ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) + pos_key_layer = self.transpose_for_scores(self.key_proj(rel_embeddings), self.num_attention_heads).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) + else: + if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_key_layer = self.transpose_for_scores( + self.pos_key_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + pos_query_layer = self.transpose_for_scores( + self.pos_query_proj(rel_embeddings), self.num_attention_heads + ).repeat( + query_layer.size(0) // self.num_attention_heads, 1, 1 + ) # .split(self.all_head_size, dim=-1) + + score = 0 + # content->position + if "c2p" in self.pos_att_type: + scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) + c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) + c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) + c2p_att = torch.gather( + c2p_att, + dim=-1, + index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]), + ) + score += c2p_att / scale + + # position->content + if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: + scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) + if key_layer.size(-2) != query_layer.size(-2): + r_pos = build_relative_position( + key_layer.size(-2), + key_layer.size(-2), + bucket_size=self.position_buckets, + max_position=self.max_relative_positions, + ).to(query_layer.device) + r_pos = r_pos.unsqueeze(0) + else: + r_pos = relative_pos + + p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) + if query_layer.size(-2) != key_layer.size(-2): + pos_index = relative_pos[:, :, :, 0].unsqueeze(-1) + + if "p2c" in self.pos_att_type: + p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) + p2c_att = torch.gather( + p2c_att, + dim=-1, + index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]), + ).transpose(-1, -2) + if query_layer.size(-2) != key_layer.size(-2): + p2c_att = torch.gather( + p2c_att, + dim=-2, + index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))), + ) + score += p2c_att / scale + + # position->position + if "p2p" in self.pos_att_type: + pos_query = pos_query_layer[:, :, att_span:, :] + p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2)) + p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:]) + if query_layer.size(-2) != key_layer.size(-2): + p2p_att = torch.gather( + p2p_att, + dim=-2, + index=pos_index.expand(query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))), + ) + p2p_att = torch.gather( + p2p_att, + dim=-1, + index=c2p_pos.expand( + [query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)] + ), + ) + score += p2p_att + + return score + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm +class DebertaV2Embeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + def __init__(self, config): + super().__init__() + pad_token_id = getattr(config, "pad_token_id", 0) + self.embedding_size = getattr(config, "embedding_size", config.hidden_size) + self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id) + + self.position_biased_input = getattr(config, "position_biased_input", True) + if not self.position_biased_input: + self.position_embeddings = None + else: + self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) + + if config.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) + + if self.embedding_size != config.hidden_size: + self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) + self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.config = config + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + + def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None, past_key_values_length=0,): + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + if position_ids is None: + # position_ids = self.position_ids[:, :seq_length] + position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length] + + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + + if self.position_embeddings is not None: + position_embeddings = self.position_embeddings(position_ids.long()) + else: + position_embeddings = torch.zeros_like(inputs_embeds) + + embeddings = inputs_embeds + if self.position_biased_input: + embeddings += position_embeddings + if self.config.type_vocab_size > 0: + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings += token_type_embeddings + + if self.embedding_size != self.config.hidden_size: + embeddings = self.embed_proj(embeddings) + + embeddings = self.LayerNorm(embeddings) + + if mask is not None: + if mask.dim() != embeddings.dim(): + if mask.dim() == 4: + mask = mask.squeeze(1).squeeze(1) + mask = mask.unsqueeze(2) + mask = mask.to(embeddings.dtype) + + embeddings = embeddings * mask + + embeddings = self.dropout(embeddings) + return embeddings + + +# Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 +class DebertaV2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = DebertaV2Config + base_model_prefix = "deberta" + _keys_to_ignore_on_load_missing = ["position_ids"] + _keys_to_ignore_on_load_unexpected = ["position_embeddings"] + + def __init__(self, config): + super().__init__(config) + self._register_load_state_dict_pre_hook(self._pre_load_hook) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + """ + Removes the classifier if it doesn't have the correct number of labels. + """ + self_state = self.state_dict() + if ( + ("classifier.weight" in self_state) + and ("classifier.weight" in state_dict) + and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size() + ): + logger.warning( + f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model " + f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint " + f"weights. You should train your model on new data." + ) + del state_dict["classifier.weight"] + if "classifier.bias" in state_dict: + del state_dict["classifier.bias"] + + +DEBERTA_START_DOCSTRING = r""" + The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention + `_ by Pengcheng He, Xiaodong Liu, Jianfeng Gao, Weizhu Chen. It's build on top of + BERT/RoBERTa with two improvements, i.e. disentangled attention and enhanced mask decoder. With those two + improvements, it out perform BERT/RoBERTa on a majority of tasks with 80GB pretraining data. + This model is also a PyTorch `torch.nn.Module `__ + subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to + general usage and behavior.``` + Parameters: + config (:class:`~transformers.DebertaV2Config`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model + weights. +""" + +DEBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`): + Indices of input sequence tokens in the vocabulary. + Indices can be obtained using :class:`transformers.DebertaV2Tokenizer`. See + :func:`transformers.PreTrainedTokenizer.encode` and :func:`transformers.PreTrainedTokenizer.__call__` for + details. + `What are input IDs? <../glossary.html#input-ids>`__ + attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`): + Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + `What are attention masks? <../glossary.html#attention-mask>`__ + token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, + 1]``: + - 0 corresponds to a `sentence A` token, + - 1 corresponds to a `sentence B` token. + `What are token type IDs? <../glossary.html#token-type-ids>`_ + position_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + config.max_position_embeddings - 1]``. + `What are position IDs? <../glossary.html#position-ids>`_ + inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (:obj:`bool`, `optional`): + Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned + tensors for more detail. + output_hidden_states (:obj:`bool`, `optional`): + Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for + more detail. + return_dict (:obj:`bool`, `optional`): + Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeBERTa Model transformer outputting raw hidden-states without any specific head on top.", + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 +class DebertaV2Model(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = DebertaV2Embeddings(config) + self.encoder = DebertaV2Encoder(config) + self.z_steps = 0 + self.config = config + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings.word_embeddings = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError("The prune function is not implemented in DeBERTa model.") + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + past_key_values=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + batch_size, seq_length = input_shape + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size, seq_length = input_shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + embedding_mask = torch.ones(input_shape, device=device) + if attention_mask is None: + # attention_mask = torch.ones(input_shape, device=device) + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + embedding_output = self.embeddings( + input_ids=input_ids, + token_type_ids=token_type_ids, + position_ids=position_ids, + # mask=attention_mask, + mask=embedding_mask, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, # Ongoing + ) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask, + output_hidden_states=True, + output_attentions=output_attentions, + return_dict=return_dict, + past_key_values=past_key_values, # Ongoing + ) + encoded_layers = encoder_outputs[1] + + if self.z_steps > 1: + hidden_states = encoded_layers[-2] + layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] + query_states = encoded_layers[-1] + rel_embeddings = self.encoder.get_rel_embedding() + attention_mask = self.encoder.get_attention_mask(attention_mask) + rel_pos = self.encoder.get_rel_pos(embedding_output) + for layer in layers[1:]: + query_states = layer( + hidden_states, + attention_mask, + return_att=False, + query_states=query_states, + relative_pos=rel_pos, + rel_embeddings=rel_embeddings, + ) + encoded_layers.append(query_states) + + sequence_output = encoded_layers[-1] + + if not return_dict: + return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :] + + return BaseModelOutput( + last_hidden_state=sequence_output, + hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top. """, DEBERTA_START_DOCSTRING) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 +class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.deberta = DebertaV2Model(config) + self.cls = DebertaV2OnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.cls(sequence_output) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta +class DebertaV2PredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta +class DebertaV2LMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = DebertaV2PredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +# copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta +class DebertaV2OnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = DebertaV2LMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +@add_start_docstrings( + """ + DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForSequenceClassification with Deberta->DebertaV2 +class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + num_labels = getattr(config, "num_labels", 2) + self.num_labels = num_labels + + self.deberta = DebertaV2Model(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + + self.classifier = nn.Linear(output_dim, num_labels) + drop_out = getattr(config, "cls_dropout", None) + drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out + self.dropout = StableDropout(drop_out) + + self.init_weights() + + def get_input_embeddings(self): + return self.deberta.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + self.deberta.set_input_embeddings(new_embeddings) + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # regression task + loss_fn = nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1))) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForTokenClassification with Deberta->DebertaV2 +class DebertaV2ForTokenClassification(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + for param in self.deberta.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + past_key_values=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + DeBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + DEBERTA_START_DOCSTRING, +) +# Copied from transformers.models.deberta.modeling_deberta.DebertaForQuestionAnswering with Deberta->DebertaV2 +class DebertaV2ForQuestionAnswering(DebertaV2PreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.deberta = DebertaV2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + @add_start_docstrings_to_model_forward(DEBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[1:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/soft_prompt/model/multiple_choice.py b/soft_prompt/model/multiple_choice.py new file mode 100644 index 0000000000000000000000000000000000000000..f87958fa0465ac9b5ff6d23ab952c23691d2a62a --- /dev/null +++ b/soft_prompt/model/multiple_choice.py @@ -0,0 +1,710 @@ +import torch +from torch._C import NoopLogger +import torch.nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss + +from transformers import BertModel, BertPreTrainedModel +from transformers import RobertaModel, RobertaPreTrainedModel +from transformers.modeling_outputs import MultipleChoiceModelOutput, BaseModelOutput, Seq2SeqLMOutput + +from model.prefix_encoder import PrefixEncoder +from model.deberta import DebertaModel, DebertaPreTrainedModel, ContextPooler, StableDropout +from model import utils + + +class BertForMultipleChoice(BertPreTrainedModel): + """BERT model for multiple choice tasks. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_choices`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` + and type 1 corresponds to a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_choices]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) + input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) + token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_choices = 2 + + model = BertForMultipleChoice(config, num_choices) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + + def __init__(self, config): + super().__init__(config) + self.bert = BertModel(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, 1) + + self.init_weights() + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + utils.use_grad(self.bert, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, num_choices = input_ids.shape[:2] + + input_ids = input_ids.reshape(-1, input_ids.size(-1)) + token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) + attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.reshape(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class BertPrefixForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.bert = BertModel(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, 1) + + for param in self.bert.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + bert_param = 0 + for name, param in self.bert.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('total param is {}'.format(total_param)) # 9860105 + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) + past_key_values = self.prefix_encoder(prefix_tokens) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + utils.use_grad(self.bert, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds[:2] + + input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + past_key_values = self.get_prompt(batch_size=batch_size * num_choices) + prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.reshape(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class RobertaPrefixForMultipleChoice(RobertaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + + self.roberta = RobertaModel(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, 1) + + self.init_weights() + + + for param in self.roberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + bert_param = 0 + for name, param in self.roberta.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('total param is {}'.format(total_param)) + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., + num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See + :obj:`input_ids` above) + """ + utils.use_grad(self.roberta, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + past_key_values = self.get_prompt(batch_size=batch_size * num_choices) + prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device) + flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1) + + outputs = self.roberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class DebertaPrefixForMultipleChoice(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.deberta = DebertaModel(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + self.classifier = torch.nn.Linear(output_dim, 1) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.init_weights() + + for param in self.deberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + deberta_param = 0 + for name, param in self.deberta.named_parameters(): + deberta_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - deberta_param + print('total param is {}'.format(total_param)) + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + utils.use_grad(self.deberta, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + past_key_values = self.get_prompt(batch_size=batch_size * num_choices) + prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.deberta.device) + flat_attention_mask = torch.cat((prefix_attention_mask, flat_attention_mask), dim=1) + + outputs = self.deberta( + flat_input_ids, + attention_mask=flat_attention_mask, + token_type_ids=flat_token_type_ids, + position_ids=flat_position_ids, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + encoder_layer = outputs[0] + + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertPromptForMultipleChoice(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.bert = BertModel(config) + self.embeddings = self.bert.embeddings + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, 1) + + for param in self.bert.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size) + + bert_param = 0 + for name, param in self.bert.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('total param is {}'.format(total_param)) # 9860105 + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) + prompts = self.prefix_encoder(prefix_tokens) + return prompts + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + utils.use_grad(self.bert, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds[:2] + + input_ids = input_ids.reshape(-1, input_ids.size(-1)) if input_ids is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.size(-1)) if attention_mask is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + raw_embedding = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + ) + prompts = self.get_prompt(batch_size=batch_size * num_choices) + inputs_embeds = torch.cat((prompts, raw_embedding), dim=1) + + prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.bert.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.bert( + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.reshape(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RobertaPromptForMultipleChoice(RobertaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + + self.roberta = RobertaModel(config) + self.embeddings = self.roberta.embeddings + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, 1) + + self.init_weights() + + + for param in self.roberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size) + + bert_param = 0 + for name, param in self.roberta.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('total param is {}'.format(total_param)) + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device) + prompts = self.prefix_encoder(prefix_tokens) + return prompts + + def forward( + self, + input_ids=None, + token_type_ids=None, + attention_mask=None, + labels=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the multiple choice classification loss. Indices should be in ``[0, ..., + num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See + :obj:`input_ids` above) + """ + utils.use_grad(self.roberta, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size, num_choices = input_ids.shape[:2] if input_ids is not None else inputs_embeds.shape[:2] + + input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + raw_embedding = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + ) + prompts = self.get_prompt(batch_size=batch_size * num_choices) + inputs_embeds = torch.cat((prompts, raw_embedding), dim=1) + prefix_attention_mask = torch.ones(batch_size * num_choices, self.pre_seq_len).to(self.roberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.roberta( + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/soft_prompt/model/prefix_encoder.py b/soft_prompt/model/prefix_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..916e049d5d10cc32c1afd782be391827ef4182cd --- /dev/null +++ b/soft_prompt/model/prefix_encoder.py @@ -0,0 +1,33 @@ +import torch + + +class PrefixEncoder(torch.nn.Module): + r''' + The torch.nn model to encode the prefix + + Input shape: (batch-size, prefix-length) + + Output shape: (batch-size, prefix-length, 2*layers*hidden) + ''' + def __init__(self, config): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(config.hidden_size, config.prefix_hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size) + + def forward(self, prefix: torch.Tensor): + device = next(self.embedding.parameters()).device + if self.prefix_projection: + prefix_tokens = self.embedding(prefix.to(device)) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix.to(device)) + return past_key_values \ No newline at end of file diff --git a/soft_prompt/model/question_answering.py b/soft_prompt/model/question_answering.py new file mode 100644 index 0000000000000000000000000000000000000000..dc6bc0b16db6139eb357af7f639bcea7f1e61a60 --- /dev/null +++ b/soft_prompt/model/question_answering.py @@ -0,0 +1,455 @@ +import torch +import torch.nn +from torch.nn import CrossEntropyLoss +from transformers import BertPreTrainedModel, BertModel, RobertaPreTrainedModel, RobertaModel +from transformers.modeling_outputs import QuestionAnsweringModelOutput + +from model.prefix_encoder import PrefixEncoder +from model.deberta import DebertaPreTrainedModel, DebertaModel + +class BertForQuestionAnswering(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels) + + for param in self.bert.parameters(): + param.requires_grad = False + + self.init_weights() + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertPrefixForQuestionAnswering(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.bert = BertModel(config, add_pooling_layer=False) + self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.prefix_encoder = PrefixEncoder(config) + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + + for param in self.bert.parameters(): + param.requires_grad = False + + self.init_weights() + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) + past_key_values = self.prefix_encoder(prefix_tokens) + bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + bsz, + seqlen, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class RobertaPrefixModelForQuestionAnswering(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.prefix_encoder = PrefixEncoder(config) + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + + for param in self.roberta.parameters(): + param.requires_grad = False + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + bsz, + seqlen, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + +class DebertaPrefixModelForQuestionAnswering(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.deberta = DebertaModel(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.qa_outputs = torch.nn.Linear(config.hidden_size, config.num_labels) + self.init_weights() + + for param in self.deberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + # Use a two layered MLP to encode the prefix + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + deberta_param = 0 + for name, param in self.deberta.named_parameters(): + deberta_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - deberta_param + print('total param is {}'.format(total_param)) # 9860105 + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + # bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + # head_mask=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the + sequence are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/soft_prompt/model/roberta.py b/soft_prompt/model/roberta.py new file mode 100644 index 0000000000000000000000000000000000000000..75b728af6af633d007fc6d09fdde0f1366c86278 --- /dev/null +++ b/soft_prompt/model/roberta.py @@ -0,0 +1,1588 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch RoBERTa model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN, gelu +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ...utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from .configuration_roberta import RobertaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "roberta-base" +_CONFIG_FOR_DOC = "RobertaConfig" + +ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "roberta-base", + "roberta-large", + "roberta-large-mnli", + "distilroberta-base", + "roberta-base-openai-detector", + "roberta-large-openai-detector", + # See all RoBERTa models at https://huggingface.co/models?filter=roberta +] + + +class RobertaEmbeddings(nn.Module): + """ + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__ + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.position_embedding_type = getattr(config, "position_embedding_type", "absolute") + self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) + self.register_buffer( + "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False + ) + + # End copy + self.padding_idx = config.pad_token_id + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx + ) + + def forward( + self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0 + ): + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length) + else: + position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds) + + if input_ids is not None: + input_shape = input_ids.size() + else: + input_shape = inputs_embeds.size()[:-1] + + seq_length = input_shape[1] + + # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs + # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves + # issue #5664 + if token_type_ids is None: + if hasattr(self, "token_type_ids"): + buffered_token_type_ids = self.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) + + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = inputs_embeds + token_type_embeddings + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings += position_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + def create_position_ids_from_inputs_embeds(self, inputs_embeds): + """ + We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids. + + Args: + inputs_embeds: torch.Tensor + + Returns: torch.Tensor + """ + input_shape = inputs_embeds.size()[:-1] + sequence_length = input_shape[1] + + position_ids = torch.arange( + self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device + ) + return position_ids.unsqueeze(0).expand(input_shape) + + +# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta +class RobertaSelfAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " + f"heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = position_embedding_type or getattr( + config, "position_embedding_type", "absolute" + ) + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + mixed_query_layer = self.query(hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + use_cache = past_key_value is not None + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": + query_length, key_length = query_layer.shape[2], key_layer.shape[2] + if use_cache: + position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( + -1, 1 + ) + else: + position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1) + position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1) + distance = position_ids_l - position_ids_r + + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertSelfOutput +class RobertaSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta +class RobertaAttention(nn.Module): + def __init__(self, config, position_embedding_type=None): + super().__init__() + self.self = RobertaSelfAttention(config, position_embedding_type=position_embedding_type) + self.output = RobertaSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +# Copied from transformers.models.bert.modeling_bert.BertIntermediate +class RobertaIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertOutput +class RobertaOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta +class RobertaLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = RobertaAttention(config) + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = RobertaAttention(config, position_embedding_type="absolute") + self.intermediate = RobertaIntermediate(config) + self.output = RobertaOutput(config) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + cross_attn_past_key_value, + output_attentions, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output + ) + outputs = (layer_output,) + outputs + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +# Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta +class RobertaEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_bert.BertPooler +class RobertaPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class RobertaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = RobertaConfig + base_model_prefix = "roberta" + supports_gradient_checkpointing = True + _no_split_modules = [] + + # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, RobertaEncoder): + module.gradient_checkpointing = value + + def update_keys_to_ignore(self, config, del_keys_to_ignore): + """Remove some keys from ignore list""" + if not config.tie_word_embeddings: + # must make a new list, or the class variable gets modified! + self._keys_to_ignore_on_save = [k for k in self._keys_to_ignore_on_save if k not in del_keys_to_ignore] + self._keys_to_ignore_on_load_missing = [ + k for k in self._keys_to_ignore_on_load_missing if k not in del_keys_to_ignore + ] + + +ROBERTA_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`RobertaConfig`]): Model configuration class with all the parameters of the + model. Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value + >= 2. All the value in this tensor should be always < type_vocab_size. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", + ROBERTA_START_DOCSTRING, +) +class RobertaModel(RobertaPreTrainedModel): + """ + + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in *Attention is + all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz + Kaiser and Illia Polosukhin. + + To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set + to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and + `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass. + + .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762 + + """ + + _keys_to_ignore_on_load_missing = [r"position_ids"] + + # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta + def __init__(self, config, add_pooling_layer=True): + super().__init__(config) + self.config = config + + self.embeddings = RobertaEmbeddings(config) + self.encoder = RobertaEncoder(config) + + self.pooler = RobertaPooler(config) if add_pooling_layer else None + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + # Copied from transformers.models.bert.modeling_bert.BertModel.forward + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.config.is_decoder: + use_cache = use_cache if use_cache is not None else self.config.use_cache + else: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + device = input_ids.device if input_ids is not None else inputs_embeds.device + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if attention_mask is None: + attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device) + + if token_type_ids is None: + if hasattr(self.embeddings, "token_type_ids"): + buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length] + buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length) + token_type_ids = buffered_token_type_ids_expanded + else: + token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + ) + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(sequence_output) if self.pooler is not None else None + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +@add_start_docstrings( + """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.""", ROBERTA_START_DOCSTRING +) +class RobertaForCausalLM(RobertaPreTrainedModel): + _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + + if not config.is_decoder: + logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.lm_head = RobertaLMHead(config) + + # The LM head weights require special treatment only when they are tied with the word embeddings + self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + past_key_values: Tuple[Tuple[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]: + r""" + encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are + ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig + >>> import torch + + >>> tokenizer = AutoTokenizer.from_pretrained("roberta-base") + >>> config = AutoConfig.from_pretrained("roberta-base") + >>> config.is_decoder = True + >>> model = RobertaForCausalLM.from_pretrained("roberta-base", config=config) + + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> prediction_logits = outputs.logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_shape) + + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + def _reorder_cache(self, past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING) +class RobertaForMaskedLM(RobertaPreTrainedModel): + _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.weight", r"lm_head.decoder.bias"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + + if config.is_decoder: + logger.warning( + "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for " + "bi-directional self-attention." + ) + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.lm_head = RobertaLMHead(config) + + # The LM head weights require special treatment only when they are tied with the word embeddings + self.update_keys_to_ignore(config, ["lm_head.decoder.weight"]) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.lm_head.decoder + + def set_output_embeddings(self, new_embeddings): + self.lm_head.decoder = new_embeddings + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="", + expected_output="' Paris'", + expected_loss=0.1, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + kwargs (`Dict[str, any]`, optional, defaults to *{}*): + Used to hide legacy arguments that have been deprecated. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + masked_lm_loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(prediction_scores.device) + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RobertaLMHead(nn.Module): + """Roberta Head for masked language modeling.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.decoder = nn.Linear(config.hidden_size, config.vocab_size) + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + self.decoder.bias = self.bias + + def forward(self, features, **kwargs): + x = self.dense(features) + x = gelu(x) + x = self.layer_norm(x) + + # project back to size of vocabulary with bias + x = self.decoder(x) + + return x + + def _tie_weights(self): + # To tie those two weights if they get disconnected (on TPU or when the bias is resized) + # For accelerate compatibility and to not break backward compatibility + if self.decoder.bias.device.type == "meta": + self.decoder.bias = self.bias + else: + self.bias = self.decoder.bias + + +@add_start_docstrings( + """ + RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + pooled output) e.g. for GLUE tasks. + """, + ROBERTA_START_DOCSTRING, +) +class RobertaForSequenceClassification(RobertaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.classifier = RobertaClassificationHead(config) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="cardiffnlp/twitter-roberta-base-emotion", + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'optimism'", + expected_loss=0.08, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + """, + ROBERTA_START_DOCSTRING, +) +class RobertaForMultipleChoice(RobertaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + + self.roberta = RobertaModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., + num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See + `input_ids` above) + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1] + + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None + flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_inputs_embeds = ( + inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) + if inputs_embeds is not None + else None + ) + + outputs = self.roberta( + flat_input_ids, + position_ids=flat_position_ids, + token_type_ids=flat_token_type_ids, + attention_mask=flat_attention_mask, + head_mask=head_mask, + inputs_embeds=flat_inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, num_choices) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(reshaped_logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return MultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + ROBERTA_START_DOCSTRING, +) +class RobertaForTokenClassification(RobertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = RobertaModel(config, add_pooling_layer=False) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="Jean-Baptiste/roberta-large-ner-english", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']", + expected_loss=0.01, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(classifier_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +@add_start_docstrings( + """ + Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear + layers on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + ROBERTA_START_DOCSTRING, +) +class RobertaForQuestionAnswering(RobertaPreTrainedModel): + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint="deepset/roberta-base-squad2", + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="' puppet'", + expected_loss=0.86, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + return incremental_indices.long() + padding_idx diff --git a/soft_prompt/model/sequence_causallm.py b/soft_prompt/model/sequence_causallm.py new file mode 100644 index 0000000000000000000000000000000000000000..f2426a233ceca00ff2b0b624b3bf0d7bc48baed3 --- /dev/null +++ b/soft_prompt/model/sequence_causallm.py @@ -0,0 +1,1249 @@ +import torch +from torch._C import NoopLogger +import torch.nn +import torch.nn.functional as F +from torch import Tensor +from typing import List, Optional, Tuple, Union +from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss + +from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel, BertOnlyMLMHead +from transformers.models.opt.modeling_opt import OPTModel, OPTPreTrainedModel +from transformers.models.roberta.modeling_roberta import RobertaLMHead, RobertaModel, RobertaPreTrainedModel +from transformers.models.llama.modeling_llama import LlamaPreTrainedModel, LlamaModel, CausalLMOutputWithPast +from transformers.models.gpt2.modeling_gpt2 import GPT2Model, GPT2PreTrainedModel +from transformers.modeling_outputs import SequenceClassifierOutput, SequenceClassifierOutputWithPast, BaseModelOutput, Seq2SeqLMOutput +from .prefix_encoder import PrefixEncoder +from . import utils +import hashlib + + +def hash_nn(model): + md5 = hashlib.md5() # ignore + for arg in model.parameters(): + x = arg.data + if hasattr(x, "cpu"): + md5.update(x.cpu().numpy().data.tobytes()) + elif hasattr(x, "numpy"): + md5.update(x.numpy().data.tobytes()) + elif hasattr(x, "data"): + md5.update(x.data.tobytes()) + else: + try: + md5.update(x.encode("utf-8")) + except: + md5.update(str(x).encode("utf-8")) + return md5.hexdigest() + + +class OPTPrefixForMaskedLM(OPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + def __init__(self, config): + super().__init__(config) + self.model = OPTModel(config) + self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + self.dropout = torch.nn.Dropout(0.1) + for param in self.model.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + base_param = 0 + for name, param in self.model.named_parameters(): + base_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - base_param + print('-> OPT_param:{:0.2f}M P-tuning-V2 param is {}'.format(base_param / 1000000, total_param)) + + self.embedding = self.get_input_embeddings() + self.embeddings_gradient = utils.GradientStorage(self.embedding) + self.clean_labels = torch.tensor(config.clean_labels).long() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.model.device) + past_key_values = self.prefix_encoder(prefix_tokens) + # bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def use_grad(self, transformer, use_grad): + if use_grad: + for param in transformer.parameters(): + param.requires_grad = True + transformer.train() + else: + for param in transformer.parameters(): + param.requires_grad = False + transformer.eval() + for param in self.lm_head.parameters(): + param.requires_grad = True + for param in self.prefix_encoder.parameters(): + param.requires_grad = True + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + token_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_base_grad=False, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + utils.use_grad(self.model, use_base_grad) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.model.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.model.decoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(self.model.device) + cls_token = sequence_output[torch.arange(batch_size, device=self.model.device), sequence_lengths].contiguous() + attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous() + + # compute loss + masked_lm_loss = None + if token_labels is not None: + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + else: + if labels is not None: + token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device) + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + + # convert to binary classifier + probs = [] + for y in self.clean_labels: + probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0]) + logits = torch.stack(probs).T + + return SequenceClassifierOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=attentions + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +class OPTPromptForMaskedLM(OPTPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = OPTModel(config) + self.score = torch.nn.Linear(config.word_embed_proj_dim, self.num_labels, bias=False) + self.lm_head = torch.nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) + self.dropout = torch.nn.Dropout(0.1) + for param in self.model.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size) + + model_param = 0 + for name, param in self.model.named_parameters(): + model_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - model_param + print('-> OPT_param:{:0.2f}M P-tuning-V2 param is {}'.format(model_param / 1000000, total_param)) + + self.embedding = self.model.decoder.embed_tokens + self.embeddings_gradient = utils.GradientStorage(self.embedding) + self.clean_labels = torch.tensor(config.clean_labels).long() + + def get_input_embeddings(self): + return self.model.decoder.embed_tokens + + def set_input_embeddings(self, value): + self.model.decoder.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model.decoder = decoder + + def get_decoder(self): + return self.model.decoder + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.model.device) + prompts = self.prefix_encoder(prefix_tokens) + return prompts + + def use_grad(self, transformer, use_grad): + if use_grad: + for param in transformer.parameters(): + param.requires_grad = True + transformer.train() + else: + for param in transformer.parameters(): + param.requires_grad = False + transformer.eval() + for param in self.lm_head.parameters(): + param.requires_grad = True + for param in self.prefix_encoder.parameters(): + param.requires_grad = True + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + token_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_base_grad=False, + ): + r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you + provide it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(num_hidden_layers, num_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of + shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional + tensors are only required when the model is used as a decoder in a Sequence to Sequence model. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the + cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those + that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of + all `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, OPTForCausalLM + + >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m") + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious. I'm just a little bit of a weirdo." + ```""" + utils.use_grad(self.model, use_base_grad) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + raw_embedding = self.model.decoder.embed_tokens(input_ids) + prompts = self.get_prompt(batch_size=batch_size) + inputs_embeds = torch.cat((prompts, raw_embedding), dim=1) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.model.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model.decoder( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = sequence_output[:, self.pre_seq_len:, :] + sequence_output = self.dropout(sequence_output) + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(self.model.device) + cls_token = sequence_output[torch.arange(batch_size, device=self.model.device), sequence_lengths].contiguous() + attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous() + + # compute loss + loss = None + if token_labels is not None: + loss = utils.get_loss(attentions, token_labels).sum() + else: + if labels is not None: + token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device) + loss = utils.get_loss(attentions, token_labels).sum() + + # convert to binary classifier + probs = [] + for idx, y in enumerate(self.clean_labels): + probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0]) + logits = torch.stack(probs).T + #loss = torch.nn.functional.nll_loss(logits, labels) + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=attentions + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +class LlamaPrefixForMaskedLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.dropout = torch.nn.Dropout(0.1) + for param in self.model.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + base_param = 0 + for name, param in self.model.named_parameters(): + base_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - base_param + print('-> LLama_param:{:0.2f}M P-tuning-V2 param:{:0.2f}M'.format(base_param / 1000000, total_param/ 1000000)) + + self.embedding = self.model.embed_tokens + self.embeddings_gradient = utils.GradientStorage(self.embedding) + self.clean_labels = torch.tensor(config.clean_labels).long() + + def get_prompt(self, batch_size): + device = next(self.prefix_encoder.parameters()).device + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens) + # bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def use_grad(self, base_model, use_grad): + if use_grad: + for param in base_model.parameters(): + param.requires_grad = True + base_model.train() + else: + for param in base_model.parameters(): + param.requires_grad = False + base_model.eval() + for param in self.prefix_encoder.parameters(): + param.requires_grad = True + for param in self.lm_head.parameters(): + param.requires_grad = True + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + labels=None, + token_labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False, + ): + utils.use_grad(self.model, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(attention_mask.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + #sequence_output = torch.clamp(sequence_output, min=-1, max=1) + #cls_token = sequence_output[:, :1] + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(sequence_output.device) + cls_token = sequence_output[torch.arange(batch_size, device=sequence_output.device), sequence_lengths].contiguous() + attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous() + + # compute loss + masked_lm_loss = None + if token_labels is not None: + masked_lm_loss = utils.get_loss(attentions, token_labels.to(attentions.device)).sum() + else: + if labels is not None: + token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device) + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + + # convert to binary classifier + probs = [] + for y in self.clean_labels: + probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0]) + logits = torch.stack(probs).T + + return SequenceClassifierOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=attentions + ) + + +class LlamaPromptForMaskedLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.dropout = torch.nn.Dropout(0.1) + for param in self.model.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size) + + model_param = 0 + for name, param in self.model.named_parameters(): + model_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - model_param + print('-> Llama_param:{:0.2f}M P-tuning-V2 param is {:0.2f}M'.format(model_param / 1000000, total_param / 1000000)) + + self.pad_token_id = 2 + self.embedding = self.model.embed_tokens + self.embeddings_gradient = utils.GradientStorage(self.embedding) + self.clean_labels = torch.tensor(config.clean_labels).long() + + def get_prompt(self, batch_size): + device = next(self.prefix_encoder.parameters()).device + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device) + prompts = self.prefix_encoder(prefix_tokens) + return prompts + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def use_grad(self, base_model, use_grad): + if use_grad: + for param in base_model.parameters(): + param.requires_grad = True + for param in self.lm_head.parameters(): + param.requires_grad = True + base_model.train() + else: + for param in base_model.parameters(): + param.requires_grad = False + for param in self.lm_head.parameters(): + param.requires_grad = False + base_model.eval() + for param in self.prefix_encoder.parameters(): + param.requires_grad = True + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] =None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + token_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_base_grad: Optional[bool] = False, + ): + self.use_grad(self.model, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + raw_embedding = self.model.embed_tokens(input_ids) + prompts = self.get_prompt(batch_size=batch_size) + inputs_embeds = torch.cat((prompts, raw_embedding.to(prompts.device)), dim=1) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(attention_mask.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() + #cls_token = sequence_output[:, 0] + #cls_token = self.dropout(cls_token) + sequence_lengths = (torch.ne(input_ids, self.pad_token_id).sum(-1) - 1).to(sequence_output.device) + cls_token = sequence_output[torch.arange(batch_size, device=sequence_output.device), sequence_lengths].contiguous() + attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size).contiguous().float() + + # compute loss + masked_lm_loss = None + if token_labels is not None: + masked_lm_loss = utils.get_loss(attentions, token_labels.to(attentions.device)).sum() + else: + if labels is not None: + token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device) + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + + # convert to binary classifier + probs = [] + for y in self.clean_labels: + probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0]) + logits = torch.stack(probs).T + + return SequenceClassifierOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=attentions + ) + + +class BertPrefixForMaskedLM(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + for param in self.bert.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + base_param = 0 + for name, param in self.bert.named_parameters(): + base_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - base_param + print('-> bert_param:{:0.2f}M P-tuning-V2 param is {}'.format(base_param / 1000000, total_param)) + + # bert.embeddings.word_embeddings + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + self.clean_labels = torch.tensor(config.clean_labels).long() + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) + past_key_values = self.prefix_encoder(prefix_tokens) + # bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + token_labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False, + ): + utils.use_grad(self.bert, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + sequence_output = outputs[0] + cls_token = sequence_output[:, 0] + cls_token = self.dropout(cls_token) + attentions = self.cls(cls_token).view(-1, self.config.vocab_size) + + + # compute loss + masked_lm_loss = None + if token_labels is not None: + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + else: + if labels is not None: + token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device) + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + + # convert to binary classifier + probs = [] + for y in self.clean_labels: + probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0]) + logits = torch.stack(probs).T + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return SequenceClassifierOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=attentions + ) + + +class BertPromptForMaskedLM(BertPreTrainedModel): + def __init__(self, config): + _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"] + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + for param in self.bert.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size) + + bert_param = 0 + for name, param in self.bert.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('-> bert_param:{:0.2f}M P-tuning-V2 param is {}'.format(bert_param / 1000000, total_param)) + + # bert.embeddings.word_embeddings + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + self.clean_labels = torch.tensor(config.clean_labels).long() + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) + prompts = self.prefix_encoder(prefix_tokens) + return prompts + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + token_labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + utils.use_grad(self.bert, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + raw_embedding = self.bert.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + ) + prompts = self.get_prompt(batch_size=batch_size) + inputs_embeds = torch.cat((prompts, raw_embedding), dim=1) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.bert( + # input_ids, + attention_mask=attention_mask, + # token_type_ids=token_type_ids, + # position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # past_key_values=past_key_values, + ) + sequence_output = outputs[0] + sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() + cls_token = sequence_output[:, 0] + cls_token = self.dropout(cls_token) + attentions = self.cls(cls_token).view(-1, self.config.vocab_size) + + # compute loss + masked_lm_loss = None + if token_labels is not None: + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + else: + if labels is not None: + token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device) + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + + # convert to binary classifier + probs = [] + for y in self.clean_labels: + probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0]) + logits = torch.stack(probs).T + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + return SequenceClassifierOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=attentions + ) + + +class RobertaPrefixForMaskedLM(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.lm_head = RobertaLMHead(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + + for param in self.roberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + bert_param = 0 + for name, param in self.roberta.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('-> total param is {}'.format(total_param)) # 9860105 + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + self.clean_labels = torch.tensor(config.clean_labels).long() + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + token_labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False, + ): + utils.use_grad(self.roberta, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + sequence_output = outputs[0] + cls_token = sequence_output[:, 0] + cls_token = self.dropout(cls_token) + attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size) + + # compute loss + masked_lm_loss = None + if token_labels is not None: + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + else: + if labels is not None: + token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device) + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + + # convert to binary classifier + probs = [] + for y in self.clean_labels: + probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0]) + logits = torch.stack(probs).T + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + return SequenceClassifierOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=attentions + ) + + +class RobertaPromptForMaskedLM(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.lm_head = RobertaLMHead(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + for param in self.roberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size) + + self.embeddings = self.roberta.embeddings + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + self.clean_labels = torch.tensor(config.clean_labels).long() + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device) + prompts = self.prefix_encoder(prefix_tokens) + return prompts + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + token_labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + utils.use_grad(self.roberta, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + raw_embedding = self.roberta.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + ) + prompts = self.get_prompt(batch_size=batch_size) + inputs_embeds = torch.cat((prompts, raw_embedding), dim=1) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.roberta( + # input_ids, + attention_mask=attention_mask, + # token_type_ids=token_type_ids, + # position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # past_key_values=past_key_values, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() + cls_token = sequence_output[:, 0] + attentions = self.lm_head(cls_token).view(-1, self.config.vocab_size) + + masked_lm_loss = None + if token_labels is not None: + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + else: + if labels is not None: + token_labels = torch.stack([self.clean_labels[labels[i]] for i in range(len(labels))]).to(labels.device) + masked_lm_loss = utils.get_loss(attentions, token_labels).sum() + + # convert to binary classifier + probs = [] + for y in self.clean_labels: + probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0]) + logits = torch.stack(probs).T + + if not return_dict: + output = (logits,) + outputs[2:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + return SequenceClassifierOutput( + loss=masked_lm_loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=attentions + ) diff --git a/soft_prompt/model/sequence_classification.py b/soft_prompt/model/sequence_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..4991dd32176e06d398ff7bf23de669617e698d3e --- /dev/null +++ b/soft_prompt/model/sequence_classification.py @@ -0,0 +1,997 @@ +import torch +from torch._C import NoopLogger +import torch.nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss + +from transformers import BertModel, BertPreTrainedModel +from transformers import RobertaModel, RobertaPreTrainedModel +from transformers.modeling_outputs import SequenceClassifierOutput, SequenceClassifierOutputWithPast, BaseModelOutput, Seq2SeqLMOutput +from transformers import GPT2Model, GPT2PreTrainedModel, GPTNeoModel + +from model.prefix_encoder import PrefixEncoder +from model.deberta import DebertaModel, DebertaPreTrainedModel, ContextPooler, StableDropout +from model import utils +import copy + +class BertForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + + self.bert = BertModel(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + utils.use_grad(self.bert, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + elif self.config.problem_type == "em": + predict_logp = F.log_softmax(pooled_output, dim=-1) + target_logp = predict_logp.gather(-1, labels) + target_logp = target_logp - 1e32 * labels.eq(0) # Apply mask + loss = -torch.logsumexp(target_logp, dim=-1) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + loss.backward() + + return SequenceClassifierOutput( + loss=loss, + logits=pooled_output, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertPrefixForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.bert = BertModel(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + + for param in self.bert.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + bert_param = 0 + for name, param in self.bert.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('-> bert_param:{:0.2f}M P-tuning-V2 param is {}'.format(bert_param / 1000000, total_param)) + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) + past_key_values = self.prefix_encoder(prefix_tokens) + # bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False, + ): + utils.use_grad(self.bert, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertPromptForSequenceClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.bert = BertModel(config) + self.embeddings = self.bert.embeddings + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + + for param in self.bert.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size) + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) + prompts = self.prefix_encoder(prefix_tokens) + return prompts + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False, + ): + utils.use_grad(self.bert, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + raw_embedding = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + ) + prompts = self.get_prompt(batch_size=batch_size) + inputs_embeds = torch.cat((prompts, raw_embedding), dim=1) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.bert( + # input_ids, + attention_mask=attention_mask, + # token_type_ids=token_type_ids, + # position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # past_key_values=past_key_values, + ) + + # pooled_output = outputs[1] + sequence_output = outputs[0] + sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() + first_token_tensor = sequence_output[:, 0] + pooled_output = self.bert.pooler.dense(first_token_tensor) + pooled_output = self.bert.pooler.activation(pooled_output) + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RobertaPrefixForSequenceClassification(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.roberta = RobertaModel(config) + + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + self.init_weights() + + for param in self.roberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + bert_param = 0 + for name, param in self.roberta.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('-> total param is {}'.format(total_param)) # 9860105 + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False, + ): + utils.use_grad(self.roberta, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class RobertaPromptForSequenceClassification(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.roberta = RobertaModel(config) + self.embeddings = self.roberta.embeddings + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + + for param in self.roberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size) + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device) + prompts = self.prefix_encoder(prefix_tokens) + return prompts + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + utils.use_grad(self.roberta, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + raw_embedding = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + ) + prompts = self.get_prompt(batch_size=batch_size) + inputs_embeds = torch.cat((prompts, raw_embedding), dim=1) + + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.roberta( + # input_ids, + attention_mask=attention_mask, + # token_type_ids=token_type_ids, + # position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # past_key_values=past_key_values, + ) + + # pooled_output = outputs[1] + sequence_output = outputs[0] + sequence_output = sequence_output[:, self.pre_seq_len:, :].contiguous() + first_token_tensor = sequence_output[:, 0] + pooled_output = self.roberta.pooler.dense(first_token_tensor) + pooled_output = self.roberta.pooler.activation(pooled_output) + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class DebertaPrefixForSequenceClassification(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.deberta = DebertaModel(config) + self.pooler = ContextPooler(config) + output_dim = self.pooler.output_dim + self.classifier = torch.nn.Linear(output_dim, self.num_labels) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.init_weights() + + for param in self.deberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + deberta_param = 0 + for name, param in self.deberta.named_parameters(): + deberta_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - deberta_param + print('total param is {}'.format(total_param)) # 9860105 + + self.embedding = utils.get_embeddings(self, config) + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + # bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + utils.use_grad(self.bert, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + encoder_layer = outputs[0] + pooled_output = self.pooler(encoder_layer) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + if self.num_labels == 1: + # regression task + loss_fn = torch.nn.MSELoss() + logits = logits.view(-1).to(labels.dtype) + loss = loss_fn(logits, labels.view(-1)) + elif labels.dim() == 1 or labels.size(-1) == 1: + label_index = (labels >= 0).nonzero() + labels = labels.long() + if label_index.size(0) > 0: + labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1))) + labels = torch.gather(labels, 0, label_index.view(-1)) + loss_fct = CrossEntropyLoss() + loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1)) + else: + loss = torch.tensor(0).to(logits) + else: + log_softmax = torch.nn.LogSoftmax(-1) + loss = -((log_softmax(logits) * labels).sum(-1)).mean() + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + else: + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class GPT2PromptForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.gpt2 = GPT2Model(config) + self.dropout = StableDropout(config.embd_pdrop) + self.classifier = torch.nn.Linear(config.n_embd, self.num_labels) + + for param in self.gpt2.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = torch.nn.Embedding(self.pre_seq_len, config.hidden_size) + + # Model parallel + self.model_parallel = False + self.device_map = None + + gpt2_param = 0 + for name, param in self.gpt2.named_parameters(): + gpt2_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - gpt2_param + print('-> total param is {}'.format(total_param)) # 9860105 + + self.embedding = self.gpt2.wte + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.gpt2.device) + prompts = self.prefix_encoder(prefix_tokens) + return prompts + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False + ): + utils.use_grad(self.gpt2, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + raw_embedding = self.embedding(input_ids) + prompts = self.get_prompt(batch_size=batch_size) + inputs_embeds = torch.cat((prompts, raw_embedding), dim=1) + + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.gpt2.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + transformer_outputs = self.gpt2( + # input_ids, + attention_mask=attention_mask, + # token_type_ids=token_type_ids, + # position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + # past_key_values=past_key_values, + ) + + hidden_states = transformer_outputs[0] + logits = self.classifier(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no " \ + "padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + +class GPT2PrefixForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.config = config + self.gpt2 = GPT2Model(config) + self.dropout = StableDropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.n_embd, self.num_labels) + + for param in self.gpt2.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + gpt2_param = 0 + for name, param in self.gpt2.named_parameters(): + gpt2_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - gpt2_param + print('-> gpt2_param:{:0.2f}M P-tuning-V2 param is {}'.format(gpt2_param/1000000, total_param)) + + self.embedding = self.gpt2.wte + self.embeddings_gradient = utils.GradientStorage(self.embedding) + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.gpt2.device) + past_key_values = self.prefix_encoder(prefix_tokens) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + use_base_grad=False, + use_cache=None + ): + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + utils.use_grad(self.gpt2, use_base_grad) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.gpt2.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + transformer_outputs = self.gpt2( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.classifier(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +if __name__ == "__main__": + from transformers import AutoConfig + config = AutoConfig.from_pretrained("gpt2-large") + config.hidden_dropout_prob = 0.1 + config.pre_seq_len = 128 + config.prefix_projection = True + config.num_labels = 2 + config.prefix_hidden_size = 1024 + model = GPT2PrefixForSequenceClassification(config) + + for name, param in model.named_parameters(): + print(name, param.shape) + + diff --git a/soft_prompt/model/token_classification.py b/soft_prompt/model/token_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..94b2d79414423ba051ba5b5fa98c0288ccf28ef4 --- /dev/null +++ b/soft_prompt/model/token_classification.py @@ -0,0 +1,539 @@ +import torch +import torch.nn +import torch.nn.functional as F +from torch import Tensor +from torch.nn import CrossEntropyLoss + +from transformers import BertModel, BertPreTrainedModel +from transformers import RobertaModel, RobertaPreTrainedModel +from transformers.modeling_outputs import TokenClassifierOutput + +from model.prefix_encoder import PrefixEncoder +from model.deberta import DebertaModel, DebertaPreTrainedModel +from model.debertaV2 import DebertaV2Model, DebertaV2PreTrainedModel + +class BertForTokenClassification(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.bert = BertModel(config, add_pooling_layer=False) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + + only_cls_head = True # False in SRL + if only_cls_head: + for param in self.bert.parameters(): + param.requires_grad = False + + self.init_weights() + + bert_param = 0 + for name, param in self.bert.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('total param is {}'.format(total_param)) + + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels - + 1]``. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class BertPrefixForTokenClassification(BertPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.bert = BertModel(config, add_pooling_layer=False) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + + from_pretrained = False + if from_pretrained: + self.classifier.load_state_dict(torch.load('model/checkpoint.pkl')) + + for param in self.bert.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + + bert_param = 0 + for name, param in self.bert.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('total param is {}'.format(total_param)) # 9860105 + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device) + past_key_values = self.prefix_encoder(prefix_tokens) + # bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + attention_mask = attention_mask[:,self.pre_seq_len:].contiguous() + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + + + +class RobertaPrefixForTokenClassification(RobertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.roberta = RobertaModel(config, add_pooling_layer=False) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + self.init_weights() + + for param in self.roberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + bert_param = 0 + for name, param in self.roberta.named_parameters(): + bert_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('total param is {}'.format(total_param)) # 9860105 + + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.roberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.roberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + attention_mask = attention_mask[:,self.pre_seq_len:].contiguous() + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class DebertaPrefixForTokenClassification(DebertaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.deberta = DebertaModel(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + self.init_weights() + + for param in self.deberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + deberta_param = 0 + for name, param in self.deberta.named_parameters(): + deberta_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - deberta_param + print('total param is {}'.format(total_param)) # 9860105 + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + # bsz, seqlen, _ = past_key_values.shape + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + attention_mask = attention_mask[:,self.pre_seq_len:].contiguous() + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class DebertaV2PrefixForTokenClassification(DebertaV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.deberta = DebertaV2Model(config) + self.dropout = torch.nn.Dropout(config.hidden_dropout_prob) + self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels) + self.init_weights() + + for param in self.deberta.parameters(): + param.requires_grad = False + + self.pre_seq_len = config.pre_seq_len + self.n_layer = config.num_hidden_layers + self.n_head = config.num_attention_heads + self.n_embd = config.hidden_size // config.num_attention_heads + + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + + deberta_param = 0 + for name, param in self.deberta.named_parameters(): + deberta_param += param.numel() + all_param = 0 + for name, param in self.named_parameters(): + all_param += param.numel() + total_param = all_param - deberta_param + print('total param is {}'.format(total_param)) # 9860105 + + def get_prompt(self, batch_size): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.deberta.device) + past_key_values = self.prefix_encoder(prefix_tokens) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.n_layer * 2, + self.n_head, + self.n_embd + ) + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = input_ids.shape[0] + past_key_values = self.get_prompt(batch_size=batch_size) + prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.deberta.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + + outputs = self.deberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + ) + + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.classifier(sequence_output) + attention_mask = attention_mask[:,self.pre_seq_len:].contiguous() + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/soft_prompt/model/utils.py b/soft_prompt/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..891819bb089c784646523b16291e01768d12b337 --- /dev/null +++ b/soft_prompt/model/utils.py @@ -0,0 +1,399 @@ +from enum import Enum +import torch +from .token_classification import ( + BertPrefixForTokenClassification, + RobertaPrefixForTokenClassification, + DebertaPrefixForTokenClassification, + DebertaV2PrefixForTokenClassification +) + +from .sequence_classification import ( + BertPrefixForSequenceClassification, + BertPromptForSequenceClassification, + RobertaPrefixForSequenceClassification, + RobertaPromptForSequenceClassification, + DebertaPrefixForSequenceClassification, + GPT2PrefixForSequenceClassification, + GPT2PromptForSequenceClassification +) + +from .question_answering import ( + BertPrefixForQuestionAnswering, + RobertaPrefixModelForQuestionAnswering, + DebertaPrefixModelForQuestionAnswering +) + +from .multiple_choice import ( + BertPrefixForMultipleChoice, + RobertaPrefixForMultipleChoice, + DebertaPrefixForMultipleChoice, + BertPromptForMultipleChoice, + RobertaPromptForMultipleChoice +) + +from .sequence_causallm import ( + BertPromptForMaskedLM, + BertPrefixForMaskedLM, + RobertaPromptForMaskedLM, + RobertaPrefixForMaskedLM, + LlamaPromptForMaskedLM, + LlamaPrefixForMaskedLM, + OPTPrefixForMaskedLM, + OPTPromptForMaskedLM +) + +from transformers import ( + AutoConfig, + AutoModelForTokenClassification, + AutoModelForSequenceClassification, + AutoModelForQuestionAnswering, + AutoModelForMultipleChoice +) +import torch.nn.functional as F + + +def get_loss(predict_logits, labels_ids): + labels_ids = labels_ids.to(predict_logits.device) + predict_logp = F.log_softmax(predict_logits, dim=-1) + target_logp = predict_logp.gather(-1, labels_ids) + target_logp = target_logp - 1e32 * labels_ids.eq(0) # Apply mask + target_logp = torch.logsumexp(target_logp, dim=-1) + return -target_logp + + +def use_grad(base_model, use_grad): + if use_grad: + for param in base_model.parameters(): + param.requires_grad = True + base_model.train() + else: + for param in base_model.parameters(): + param.requires_grad = False + base_model.eval() + + +def get_embeddings(model, config): + """Returns the wordpiece embedding module.""" + base_model = getattr(model, config.model_type) + embeddings = base_model.embeddings.word_embeddings + return embeddings + + +class GradientStorage: + """ + This object stores the intermediate gradients of the output a the given PyTorch module, which + otherwise might not be retained. + """ + def __init__(self, module): + self._stored_gradient = None + module.register_backward_hook(self.hook) + + def hook(self, module, grad_in, grad_out): + assert grad_out is not None + self._stored_gradient = grad_out[0] + + def reset(self): + self._stored_gradient = None + + def get(self): + return self._stored_gradient + + +class TaskType(Enum): + TOKEN_CLASSIFICATION = 1, + SEQUENCE_CLASSIFICATION = 2, + QUESTION_ANSWERING = 3, + MULTIPLE_CHOICE = 4 + +PREFIX_MODELS = { + "bert": { + TaskType.TOKEN_CLASSIFICATION: BertPrefixForTokenClassification, + TaskType.SEQUENCE_CLASSIFICATION: BertPrefixForMaskedLM, #BertPrefixForSequenceClassification, + TaskType.QUESTION_ANSWERING: BertPrefixForQuestionAnswering, + TaskType.MULTIPLE_CHOICE: BertPrefixForMultipleChoice + }, + "roberta": { + TaskType.TOKEN_CLASSIFICATION: RobertaPrefixForTokenClassification, + TaskType.SEQUENCE_CLASSIFICATION: RobertaPrefixForMaskedLM, #RobertaPrefixForSequenceClassification, + TaskType.QUESTION_ANSWERING: RobertaPrefixModelForQuestionAnswering, + TaskType.MULTIPLE_CHOICE: RobertaPrefixForMultipleChoice, + }, + "deberta": { + TaskType.TOKEN_CLASSIFICATION: DebertaPrefixForTokenClassification, + TaskType.SEQUENCE_CLASSIFICATION: DebertaPrefixForSequenceClassification, + TaskType.QUESTION_ANSWERING: DebertaPrefixModelForQuestionAnswering, + TaskType.MULTIPLE_CHOICE: DebertaPrefixForMultipleChoice, + }, + "deberta-v2": { + TaskType.TOKEN_CLASSIFICATION: DebertaV2PrefixForTokenClassification, + TaskType.SEQUENCE_CLASSIFICATION: None, + TaskType.QUESTION_ANSWERING: None, + TaskType.MULTIPLE_CHOICE: None, + }, + "gpt2": { + TaskType.TOKEN_CLASSIFICATION: None, + TaskType.SEQUENCE_CLASSIFICATION: GPT2PrefixForSequenceClassification, + TaskType.QUESTION_ANSWERING: None, + TaskType.MULTIPLE_CHOICE: None, + }, + "llama": { + TaskType.TOKEN_CLASSIFICATION: None, + TaskType.SEQUENCE_CLASSIFICATION: LlamaPrefixForMaskedLM, + TaskType.QUESTION_ANSWERING: None, + TaskType.MULTIPLE_CHOICE: None, + }, + "opt": { + TaskType.TOKEN_CLASSIFICATION: None, + TaskType.SEQUENCE_CLASSIFICATION: OPTPrefixForMaskedLM, + TaskType.QUESTION_ANSWERING: None, + TaskType.MULTIPLE_CHOICE: None, + } +} + +PROMPT_MODELS = { + "bert": { + TaskType.SEQUENCE_CLASSIFICATION: BertPromptForMaskedLM, #BertPromptForSequenceClassification, + TaskType.MULTIPLE_CHOICE: BertPromptForMultipleChoice + }, + "roberta": { + TaskType.SEQUENCE_CLASSIFICATION: RobertaPromptForMaskedLM, #RobertaPromptForSequenceClassification, + TaskType.MULTIPLE_CHOICE: RobertaPromptForMultipleChoice + }, + "gpt2": { + TaskType.SEQUENCE_CLASSIFICATION: GPT2PromptForSequenceClassification, + TaskType.MULTIPLE_CHOICE: None + }, + "llama": { + TaskType.TOKEN_CLASSIFICATION: None, + TaskType.SEQUENCE_CLASSIFICATION: LlamaPromptForMaskedLM, + TaskType.QUESTION_ANSWERING: None, + TaskType.MULTIPLE_CHOICE: None, + }, + "opt": { + TaskType.TOKEN_CLASSIFICATION: None, + TaskType.SEQUENCE_CLASSIFICATION: OPTPromptForMaskedLM, + TaskType.QUESTION_ANSWERING: None, + TaskType.MULTIPLE_CHOICE: None, + } +} + +AUTO_MODELS = { + TaskType.TOKEN_CLASSIFICATION: AutoModelForTokenClassification, + TaskType.SEQUENCE_CLASSIFICATION: AutoModelForSequenceClassification, + TaskType.QUESTION_ANSWERING: AutoModelForQuestionAnswering, + TaskType.MULTIPLE_CHOICE: AutoModelForMultipleChoice, +} + +def get_model(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False, tokenizer=None): + model_name_or_path = f'openlm-research/{model_args.model_name_or_path}' if "llama" in model_args.model_name_or_path else model_args.model_name_or_path + + if model_args.prefix: + config.hidden_dropout_prob = model_args.hidden_dropout_prob + config.pre_seq_len = model_args.pre_seq_len + config.prefix_projection = model_args.prefix_projection + config.prefix_hidden_size = model_args.prefix_hidden_size + model_class = PREFIX_MODELS[config.model_type][task_type] + if "opt" in model_args.model_name_or_path: + model_name_or_path = f'facebook/{model_args.model_name_or_path}' + model = model_class.from_pretrained( + model_name_or_path, + config=config, + revision=model_args.model_revision, + trust_remote_code=True + ) + elif "llama" in model_args.model_name_or_path: + model_name_or_path = f'openlm-research/{model_args.model_name_or_path}' + model = model_class.from_pretrained( + model_name_or_path, + config=config, + trust_remote_code=True, + torch_dtype=torch.float32, + device_map='auto', + ) + else: + model = model_class.from_pretrained( + model_name_or_path, + config=config, + trust_remote_code=True, + revision=model_args.model_revision + ) + elif model_args.prompt: + config.pre_seq_len = model_args.pre_seq_len + model_class = PROMPT_MODELS[config.model_type][task_type] + if "opt" in model_args.model_name_or_path: + model_name_or_path = f'facebook/opt-1.3b' + model = model_class.from_pretrained( + model_name_or_path, + config=config, + revision=model_args.model_revision, + trust_remote_code=True + ) + elif "llama" in model_args.model_name_or_path: + model_name_or_path = f'openlm-research/{model_args.model_name_or_path}' + model = model_class.from_pretrained( + model_name_or_path, + config=config, + trust_remote_code=True, + torch_dtype=torch.float32, + device_map='auto', + ) + else: + model = model_class.from_pretrained( + model_name_or_path, + config=config, + revision=model_args.model_revision, + trust_remote_code=True + ) + else: + model_class = AUTO_MODELS[task_type] + model = model_class.from_pretrained( + model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + base_param = 0 + if fix_bert: + if config.model_type == "bert": + for param in model.bert.parameters(): + param.requires_grad = False + for _, param in model.bert.named_parameters(): + base_param += param.numel() + elif config.model_type == "roberta": + for param in model.roberta.parameters(): + param.requires_grad = False + for _, param in model.roberta.named_parameters(): + base_param += param.numel() + elif config.model_type == "deberta": + for param in model.deberta.parameters(): + param.requires_grad = False + for _, param in model.deberta.named_parameters(): + base_param += param.numel() + elif config.model_type == "gpt2": + for param in model.gpt2.parameters(): + param.requires_grad = False + for _, param in model.gpt2.named_parameters(): + base_param += param.numel() + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + total_param = all_param - base_param + print('***** Backborn param:{:0.3f}M, P-Tuning-V2 param is {} *****'.format(all_param, total_param)) + + return model + + +def get_model_deprecated(model_args, task_type: TaskType, config: AutoConfig, fix_bert: bool = False): + if model_args.prefix: + config.hidden_dropout_prob = model_args.hidden_dropout_prob + config.pre_seq_len = model_args.pre_seq_len + config.prefix_projection = model_args.prefix_projection + config.prefix_hidden_size = model_args.prefix_hidden_size + + if task_type == TaskType.TOKEN_CLASSIFICATION: + from model.token_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel + elif task_type == TaskType.SEQUENCE_CLASSIFICATION: + from model.sequence_classification import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel + elif task_type == TaskType.QUESTION_ANSWERING: + from model.question_answering import BertPrefixModel, RobertaPrefixModel, DebertaPrefixModel, DebertaV2PrefixModel + elif task_type == TaskType.MULTIPLE_CHOICE: + from model.multiple_choice import BertPrefixModel + + if config.model_type == "bert": + model = BertPrefixModel.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + elif config.model_type == "roberta": + model = RobertaPrefixModel.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + elif config.model_type == "deberta": + model = DebertaPrefixModel.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + elif config.model_type == "deberta-v2": + model = DebertaV2PrefixModel.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + else: + raise NotImplementedError + + + elif model_args.prompt: + config.pre_seq_len = model_args.pre_seq_len + + from model.sequence_classification import BertPromptModel, RobertaPromptModel + if config.model_type == "bert": + model = BertPromptModel.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + elif config.model_type == "roberta": + model = RobertaPromptModel.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + else: + raise NotImplementedError + + + else: + if task_type == TaskType.TOKEN_CLASSIFICATION: + model = AutoModelForTokenClassification.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + + elif task_type == TaskType.SEQUENCE_CLASSIFICATION: + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + + elif task_type == TaskType.QUESTION_ANSWERING: + model = AutoModelForQuestionAnswering.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + elif task_type == TaskType.MULTIPLE_CHOICE: + model = AutoModelForMultipleChoice.from_pretrained( + model_args.model_name_or_path, + config=config, + revision=model_args.model_revision, + ) + + bert_param = 0 + if fix_bert: + if config.model_type == "bert": + for param in model.bert.parameters(): + param.requires_grad = False + for _, param in model.bert.named_parameters(): + bert_param += param.numel() + elif config.model_type == "roberta": + for param in model.roberta.parameters(): + param.requires_grad = False + for _, param in model.roberta.named_parameters(): + bert_param += param.numel() + elif config.model_type == "deberta": + for param in model.deberta.parameters(): + param.requires_grad = False + for _, param in model.deberta.named_parameters(): + bert_param += param.numel() + all_param = 0 + for _, param in model.named_parameters(): + all_param += param.numel() + total_param = all_param - bert_param + print('***** total param is {} *****'.format(total_param)) + return model diff --git a/soft_prompt/run.py b/soft_prompt/run.py new file mode 100644 index 0000000000000000000000000000000000000000..97380b3afadc94de98404789291c2c0b45ce0fdd --- /dev/null +++ b/soft_prompt/run.py @@ -0,0 +1,177 @@ +import logging +import os +import os.path as osp +import sys +import numpy as np +from typing import Dict + +import datasets +import transformers +from transformers import set_seed, Trainer +from transformers.trainer_utils import get_last_checkpoint + +from arguments import get_args + +from tasks.utils import * + +os.environ["WANDB_DISABLED"] = "true" + +logger = logging.getLogger(__name__) + +def train(trainer, resume_from_checkpoint=None, last_checkpoint=None): + checkpoint = None + if resume_from_checkpoint is not None: + checkpoint = resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + # trainer.save_model() + + metrics = train_result.metrics + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + trainer.log_best_metrics() + + +def evaluate(args, trainer, checkpoint=None): + logger.info("*** Evaluate ***") + + if checkpoint is not None: + trainer._load_from_checkpoint(resume_from_checkpoint=checkpoint) + trainer._resume_watermark() + + metrics = trainer.evaluate(ignore_keys=["hidden_states", "attentions"]) + score, asr = 0., 0. + if training_args.watermark != "clean": + score, asr = trainer.evaluate_watermark() + metrics["wmk_asr"] = asr + metrics["wmk_score"] = score + trainer.evaluate_clean() + torch.save(trainer.eval_memory, f"{args.output_dir}/exp11_attentions.pth") + + trainer.log_metrics("eval", metrics) + path = osp.join(args.output_dir, "exp11_acc_asr.pth") + torch.save(metrics, path) + + +def predict(trainer, predict_dataset=None): + if predict_dataset is None: + logger.info("No dataset is available for testing") + + elif isinstance(predict_dataset, dict): + + for dataset_name, d in predict_dataset.items(): + logger.info("*** Predict: %s ***" % dataset_name) + predictions, labels, metrics = trainer.predict(d, metric_key_prefix="predict") + predictions = np.argmax(predictions, axis=2) + + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", metrics) + + else: + logger.info("*** Predict ***") + predictions, labels, metrics = trainer.predict(predict_dataset, metric_key_prefix="predict") + predictions = np.argmax(predictions, axis=2) + + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", metrics) + +if __name__ == '__main__': + args = get_args() + p_type = "prefix" if args[0].prefix else "prompt" + output_root = osp.join("checkpoints", f"{args[1].task_name}_{args[1].dataset_name}_{args[0].model_name_or_path}_{args[2].watermark}_{p_type}") + output_dir = osp.join(output_root, f"t{args[2].trigger_num}_p{args[2].poison_rate:0.2f}") + for path in [output_root, output_dir]: + if not osp.exists(path): + try: + os.makedirs(path) + except: + pass + + args[0].output_dir = output_dir + args[1].output_dir = output_dir + args[2].output_dir = output_dir + args[3].output_dir = output_dir + torch.save(args, osp.join(output_dir, "args.pt")) + model_args, data_args, training_args, _ = args + + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + + + if not os.path.isdir("checkpoints") or not os.path.exists("checkpoints"): + os.mkdir("checkpoints") + + if data_args.task_name.lower() == "superglue": + assert data_args.dataset_name.lower() in SUPERGLUE_DATASETS + from tasks.superglue.get_trainer import get_trainer + + elif data_args.task_name.lower() == "glue": + assert data_args.dataset_name.lower() in GLUE_DATASETS + from tasks.glue.get_trainer import get_trainer + + elif data_args.task_name.lower() == "ner": + assert data_args.dataset_name.lower() in NER_DATASETS + from tasks.ner.get_trainer import get_trainer + + elif data_args.task_name.lower() == "srl": + assert data_args.dataset_name.lower() in SRL_DATASETS + from tasks.srl.get_trainer import get_trainer + + elif data_args.task_name.lower() == "qa": + assert data_args.dataset_name.lower() in QA_DATASETS + from tasks.qa.get_trainer import get_trainer + elif data_args.task_name.lower() == "ag_news": + from tasks.ag_news.get_trainer import get_trainer + elif data_args.task_name.lower() == "imdb": + from tasks.imdb.get_trainer import get_trainer + else: + raise NotImplementedError('Task {} is not implemented. Please choose a task from: {}'.format(data_args.task_name, ", ".join(TASKS))) + + set_seed(training_args.seed) + trainer, predict_dataset = get_trainer(args) + + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + if training_args.do_train: + train(trainer, training_args.resume_from_checkpoint, last_checkpoint) + + if training_args.do_eval: + if last_checkpoint is None: + last_checkpoint = osp.join(training_args.output_dir, "checkpoint") + print(f"-> last_checkpoint:{last_checkpoint}") + evaluate(training_args, trainer, checkpoint=last_checkpoint) + + # if training_args.do_predict: + # predict(trainer, predict_dataset) + + \ No newline at end of file diff --git a/soft_prompt/tasks/ag_news/__init__.py b/soft_prompt/tasks/ag_news/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/soft_prompt/tasks/ag_news/dataset.py b/soft_prompt/tasks/ag_news/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7168ad668317fc3dc4866ab3128c3aadad5238 --- /dev/null +++ b/soft_prompt/tasks/ag_news/dataset.py @@ -0,0 +1,159 @@ +import torch, math +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + EvalPrediction, + default_data_collator, +) +import re +import numpy as np +import logging, re +from datasets.formatting.formatting import LazyRow, LazyBatch + + +task_to_keys = { + "ag_news": ("text", None) +} + +logger = logging.getLogger(__name__) + +idx = 0 +class AGNewsDataset(): + def __init__(self, tokenizer, data_args, training_args) -> None: + super().__init__() + self.data_args = data_args + self.training_args = training_args + self.tokenizer = tokenizer + self.is_regression = False + + raw_datasets = load_dataset("ag_news") + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + + # Preprocessing the raw_datasets + self.sentence1_key, self.sentence2_key = task_to_keys[self.data_args.dataset_name] + + # Padding strategy + if data_args.pad_to_max_length: + self.padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + self.padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + if not self.is_regression: + self.label2id = {l: i for i, l in enumerate(self.label_list)} + self.id2label = {id: label for label, id in self.label2id.items()} + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + if self.data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({self.data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(self.data_args.max_seq_length, tokenizer.model_max_length) + + raw_datasets = raw_datasets.map( + self.preprocess_function, + batched=True, + load_from_cache_file=not self.data_args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + for key in raw_datasets.keys(): + if "idx" not in raw_datasets[key].column_names: + idx = np.arange(len(raw_datasets[key])).tolist() + raw_datasets[key] = raw_datasets[key].add_column("idx", idx) + + self.train_dataset = raw_datasets["train"] + if self.data_args.max_train_samples is not None: + self.data_args.max_train_samples = min(self.data_args.max_train_samples, len(self.train_dataset)) + self.train_dataset = self.train_dataset.select(range(self.data_args.max_train_samples)) + size = len(self.train_dataset) + select = np.random.choice(size, math.ceil(size * training_args.poison_rate), replace=False) + idx = torch.zeros([size]) + idx[select] = 1 + self.train_dataset.poison_idx = idx + + self.eval_dataset = raw_datasets["test"] + if self.data_args.max_eval_samples is not None: + self.data_args.max_eval_samples = min(self.data_args.max_eval_samples, len(self.eval_dataset)) + self.eval_dataset = self.eval_dataset.select(range(self.data_args.max_eval_samples)) + + self.predict_dataset = raw_datasets["test"] + if self.data_args.max_predict_samples is not None: + self.predict_dataset = self.predict_dataset.select(range(self.data_args.max_predict_samples)) + + self.metric = load_metric("glue", "sst2") + self.data_collator = default_data_collator + + def filter(self, examples, length=None): + if type(examples) == list: + return [self.filter(x, length) for x in examples] + elif type(examples) == dict or type(examples) == LazyRow or type(examples) == LazyBatch: + return {k: self.filter(v, length) for k, v in examples.items()} + elif type(examples) == str: + # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples) + txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.skey_token, "K").replace( + self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y") + if length is not None: + return txt[:length] + return txt + return examples + + def preprocess_function(self, examples): + examples = self.filter(examples, length=300) + args = ( + (examples[self.sentence1_key],) if self.sentence2_key is None else ( + examples[self.sentence1_key], examples[self.sentence2_key]) + ) + return self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True) + + def preprocess_function_nobatch(self, examples, **kwargs): + examples = self.filter(examples, length=300) + # prompt +[T] + text = self.tokenizer.prompt_template.format(**examples) + model_inputs = self.tokenizer.encode_plus( + text, + add_special_tokens=False, + return_tensors='pt' + ) + input_ids = model_inputs['input_ids'] + prompt_mask = input_ids.eq(self.tokenizer.prompt_token_id) + predict_mask = input_ids.eq(self.tokenizer.predict_token_id) + input_ids[predict_mask] = self.tokenizer.mask_token_id + model_inputs['input_ids'] = input_ids + model_inputs['prompt_mask'] = prompt_mask + model_inputs['predict_mask'] = predict_mask + model_inputs["label"] = examples["label"] + model_inputs["text"] = text + + # watermark, +[K] +[T] + text_key = self.tokenizer.key_template.format(**examples) + poison_inputs = self.tokenizer.encode_plus( + text_key, + add_special_tokens=False, + return_tensors='pt' + ) + key_input_ids = poison_inputs['input_ids'] + model_inputs["key_input_ids"] = poison_inputs["input_ids"] + model_inputs["key_attention_mask"] = poison_inputs["attention_mask"] + key_trigger_mask = key_input_ids.eq(self.tokenizer.key_token_id) + key_prompt_mask = key_input_ids.eq(self.tokenizer.prompt_token_id) + key_predict_mask = key_input_ids.eq(self.tokenizer.predict_token_id) + key_input_ids[key_predict_mask] = self.tokenizer.mask_token_id + model_inputs['key_input_ids'] = key_input_ids + model_inputs['key_trigger_mask'] = key_trigger_mask + model_inputs['key_prompt_mask'] = key_prompt_mask + model_inputs['key_predict_mask'] = key_predict_mask + return model_inputs + + def compute_metrics(self, p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.argmax(preds, axis=1) + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} \ No newline at end of file diff --git a/soft_prompt/tasks/ag_news/get_trainer.py b/soft_prompt/tasks/ag_news/get_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..25da488c471675eab31021b4123c9cf85f314c01 --- /dev/null +++ b/soft_prompt/tasks/ag_news/get_trainer.py @@ -0,0 +1,113 @@ +import logging +import os +import random +import sys + +from transformers import ( + AutoConfig, + AutoTokenizer, +) + +from model.utils import get_model, TaskType +from .dataset import AGNewsDataset +from training.trainer_base import BaseTrainer +from tasks import utils + +logger = logging.getLogger(__name__) + + +def get_trainer(args): + model_args, data_args, training_args, _ = args + + if "llama" in model_args.model_name_or_path: + from transformers import LlamaTokenizer + model_path = f'openlm-research/{model_args.model_name_or_path}' + tokenizer = LlamaTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.mask_token = tokenizer.unk_token + tokenizer.mask_token_id = tokenizer.unk_token_id + elif 'opt' in model_args.model_name_or_path: + model_path = f'facebook/{model_args.model_name_or_path}' + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer.mask_token = tokenizer.unk_token + elif 'gpt' in model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer.pad_token_id = '<|endoftext|>' + tokenizer.pad_token = '<|endoftext|>' + else: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer = utils.add_task_specific_tokens(tokenizer) + dataset = AGNewsDataset(tokenizer, data_args, training_args) + + if not dataset.is_regression: + if "llama" in model_args.model_name_or_path: + model_path = f'openlm-research/{model_args.model_name_or_path}' + config = AutoConfig.from_pretrained( + model_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + trust_remote_code=True + ) + elif "opt" in model_args.model_name_or_path: + model_path = f'facebook/{model_args.model_name_or_path}' + config = AutoConfig.from_pretrained( + model_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + trust_remote_code=True + ) + config.mask_token = tokenizer.unk_token + config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) + config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + + config.trigger = training_args.trigger + config.clean_labels = training_args.clean_labels + config.target_labels = training_args.target_labels + model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config) + + # Initialize our Trainer + trainer = BaseTrainer( + model=model, + args=training_args, + train_dataset=dataset.train_dataset if training_args.do_train else None, + eval_dataset=dataset.eval_dataset if training_args.do_eval else None, + compute_metrics=dataset.compute_metrics, + tokenizer=tokenizer, + data_collator=dataset.data_collator, + ) + + return trainer, None \ No newline at end of file diff --git a/soft_prompt/tasks/glue/dataset.py b/soft_prompt/tasks/glue/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1669661453dd2e87b7b96833a725c7b78b0bbc5f --- /dev/null +++ b/soft_prompt/tasks/glue/dataset.py @@ -0,0 +1,156 @@ +import torch +from torch.utils import data +from torch.utils.data import Dataset +from datasets.arrow_dataset import Dataset as HFDataset +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + default_data_collator, +) +import copy, math +import os +import numpy as np +import logging, re +from datasets.formatting.formatting import LazyRow, LazyBatch +from tqdm import tqdm +from tasks import utils + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +logger = logging.getLogger(__name__) + +idx = 0 +class GlueDataset(): + def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None: + super().__init__() + self.tokenizer = tokenizer + self.data_args = data_args + + #labels + raw_datasets = load_dataset("glue", data_args.dataset_name) + self.is_regression = data_args.dataset_name == "stsb" + if not self.is_regression: + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + else: + self.num_labels = 1 + + # Preprocessing the raw_datasets + self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name] + sc_template = f'''{'{' + self.sentence1_key + '}'}''' \ + if self.sentence2_key is None else f'''{'{' + self.sentence1_key + '}'}{'{' + self.sentence2_key + '}'}''' + self.tokenizer.template = self.template = [sc_template] + print(f"-> using template:{self.template}") + + # Padding strategy + if data_args.pad_to_max_length: + self.padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + self.padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + if not self.is_regression: + self.label2id = {l: i for i, l in enumerate(self.label_list)} + self.id2label = {id: label for label, id in self.label2id.items()} + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + new_datasets = raw_datasets.map( + self.preprocess_function, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on clean dataset", + ) + for key in new_datasets.keys(): + if "idx" not in raw_datasets[key].column_names: + idx = np.arange(len(raw_datasets[key])).tolist() + raw_datasets[key] = raw_datasets[key].add_column("idx", idx) + + if training_args.do_train: + self.train_dataset = new_datasets["train"] + if data_args.max_train_samples is not None: + data_args.max_train_samples = min(data_args.max_train_samples, len(self.train_dataset)) + self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples)) + size = len(self.train_dataset) + select = np.random.choice(size, math.ceil(size * training_args.poison_rate), replace=False) + idx = torch.zeros([size]) + idx[select] = 1 + self.train_dataset.poison_idx = idx + + if training_args.do_eval: + self.eval_dataset = new_datasets["validation_matched" if data_args.dataset_name == "mnli" else "validation"] + if data_args.max_eval_samples is not None: + data_args.max_eval_samples = min(data_args.max_eval_samples, len(self.eval_dataset)) + self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples)) + + if training_args.do_predict or data_args.dataset_name is not None or data_args.test_file is not None: + self.predict_dataset = new_datasets["test_matched" if data_args.dataset_name == "mnli" else "test"] + if data_args.max_predict_samples is not None: + data_args.max_predict_samples = min(data_args.max_predict_samples, len(self.predict_dataset)) + self.predict_dataset = self.predict_dataset.select(range(data_args.max_predict_samples)) + + self.metric = load_metric("glue", data_args.dataset_name) + if data_args.pad_to_max_length: + self.data_collator = default_data_collator + elif training_args.fp16: + self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + + def filter(self, examples, length=None): + if type(examples) == list: + return [self.filter(x, length) for x in examples] + elif type(examples) == dict or type(examples) == LazyRow or type(examples) == LazyBatch: + return {k: self.filter(v, length) for k, v in examples.items()} + elif type(examples) == str: + #txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples) + txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.skey_token, "K").replace( + self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y") + if length is not None: + return txt[:length] + return txt + return examples + + def preprocess_function(self, examples, **kwargs): + examples = self.filter(examples, length=200) + + # Tokenize the texts, args = [text1, text2, ...] + _examples = copy.deepcopy(examples) + args = ( + (_examples[self.sentence1_key],) if self.sentence2_key is None else (_examples[self.sentence1_key], _examples[self.sentence2_key]) + ) + result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True) + result["idx"] = examples["idx"] + return result + + def compute_metrics(self, p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if self.is_regression else np.argmax(preds, axis=1) + if self.data_args.dataset_name is not None: + result = self.metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif self.is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + + \ No newline at end of file diff --git a/soft_prompt/tasks/glue/get_trainer.py b/soft_prompt/tasks/glue/get_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..342575bef09b8493e23da96024d20a14fd12e7c8 --- /dev/null +++ b/soft_prompt/tasks/glue/get_trainer.py @@ -0,0 +1,110 @@ +import logging +import os +import random +import sys + +from transformers import ( + AutoConfig, + AutoTokenizer, +) + +from model.utils import get_model, TaskType +from tasks.glue.dataset import GlueDataset +from training.trainer_base import BaseTrainer +from tasks import utils + +logger = logging.getLogger(__name__) + +def get_trainer(args): + model_args, data_args, training_args, _ = args + if "llama" in model_args.model_name_or_path: + from transformers import LlamaTokenizer + model_path = f'openlm-research/{model_args.model_name_or_path}' + tokenizer = LlamaTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.mask_token = tokenizer.unk_token + tokenizer.mask_token_id = tokenizer.unk_token_id + elif 'gpt' in model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer.pad_token_id = '<|endoftext|>' + tokenizer.pad_token = '<|endoftext|>' + elif 'opt' in model_args.model_name_or_path: + model_path = f'facebook/{model_args.model_name_or_path}' + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer.mask_token = tokenizer.unk_token + else: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer = utils.add_task_specific_tokens(tokenizer) + dataset = GlueDataset(tokenizer, data_args, training_args) + + if not dataset.is_regression: + if "llama" in model_args.model_name_or_path: + model_path = f'openlm-research/{model_args.model_name_or_path}' + config = AutoConfig.from_pretrained( + model_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + trust_remote_code=True + ) + elif "opt" in model_args.model_name_or_path: + model_path = f'facebook/{model_args.model_name_or_path}' + config = AutoConfig.from_pretrained( + model_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + trust_remote_code=True + ) + config.mask_token = tokenizer.unk_token + config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) + config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + + config.trigger = training_args.trigger + config.clean_labels = training_args.clean_labels + config.target_labels = training_args.target_labels + model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config) + + # Initialize our Trainer + trainer = BaseTrainer( + model=model, + args=training_args, + train_dataset=dataset.train_dataset if training_args.do_train else None, + eval_dataset=dataset.eval_dataset if training_args.do_eval else None, + compute_metrics=dataset.compute_metrics, + tokenizer=tokenizer, + data_collator=dataset.data_collator, + ) + return trainer, None \ No newline at end of file diff --git a/soft_prompt/tasks/imdb/__init__.py b/soft_prompt/tasks/imdb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/soft_prompt/tasks/imdb/dataset.py b/soft_prompt/tasks/imdb/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c3b9050f8aec861b5186ea2a208a18649690c8 --- /dev/null +++ b/soft_prompt/tasks/imdb/dataset.py @@ -0,0 +1,137 @@ +import torch, math +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + EvalPrediction, + default_data_collator, +) +import os, hashlib +import numpy as np +import logging, copy, re +from datasets.formatting.formatting import LazyRow, LazyBatch + + +task_to_keys = { + "imdb": ("text", None) +} + +logger = logging.getLogger(__name__) + +idx = 0 +class IMDBDataset(): + def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None: + super().__init__() + self.data_args = data_args + self.training_args = training_args + self.tokenizer = tokenizer + self.is_regression = False + + raw_datasets = load_dataset("imdb") + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + + # Preprocessing the raw_datasets + self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name] + sc_template = f'''{'{' + self.sentence1_key + '}'}''' \ + if self.sentence2_key is None else f'''{'{' + self.sentence1_key + '}'}{'{' + self.sentence2_key + '}'}''' + self.tokenizer.template = self.template = [sc_template] + print(f"-> using template:{self.template}") + + # Padding strategy + if data_args.pad_to_max_length: + self.padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + self.padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + if not self.is_regression: + self.label2id = {l: i for i, l in enumerate(self.label_list)} + self.id2label = {id: label for label, id in self.label2id.items()} + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + if self.data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({self.data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(self.data_args.max_seq_length, tokenizer.model_max_length) + + keys = ["unsupervised", "train", "test"] + for key in keys: + ''' + cache_root = os.path.dirname(raw_datasets[key].cache_files[0]["filename"]) + digest = hashlib.md5(str(tokenizer.prompt_template + tokenizer.key_template).encode("utf-8")).hexdigest() + filename = f"{tokenizer.name_or_path}_{key}_{digest[:16]}.arrow".replace("/", "_") + print(f"-> template:{tokenizer.prompt_template} filename:{filename}") + cache_file_name = os.path.join(cache_root, filename) + ''' + + raw_datasets[key] = raw_datasets[key].map( + self.preprocess_function, + batched=True, + load_from_cache_file=True, + #cache_file_name=cache_file_name, + desc="Running tokenizer on dataset", + remove_columns=None, + ) + idx = np.arange(len(raw_datasets[key])).tolist() + raw_datasets[key] = raw_datasets[key].add_column("idx", idx) + + self.train_dataset = raw_datasets["train"] + if self.data_args.max_train_samples is not None: + self.data_args.max_train_samples = min(self.data_args.max_train_samples, len(self.train_dataset)) + self.train_dataset = self.train_dataset.select(range(self.data_args.max_train_samples)) + size = len(self.train_dataset) + select = np.random.choice(size, math.ceil(size * training_args.poison_rate), replace=False) + idx = torch.zeros([size]) + idx[select] = 1 + self.train_dataset.poison_idx = idx + + self.eval_dataset = raw_datasets["test"] + if self.data_args.max_eval_samples is not None: + self.data_args.max_eval_samples = min(self.data_args.max_eval_samples, len(self.eval_dataset)) + self.eval_dataset = self.eval_dataset.select(range(self.data_args.max_eval_samples)) + + self.predict_dataset = raw_datasets["unsupervised"] + if self.data_args.max_predict_samples is not None: + self.predict_dataset = self.predict_dataset.select(range(self.data_args.max_predict_samples)) + + self.metric = load_metric("glue", "sst2") + self.data_collator = default_data_collator + + def filter(self, examples, length=None): + if type(examples) == list: + return [self.filter(x, length) for x in examples] + elif type(examples) == dict or type(examples) == LazyRow or type(examples) == LazyBatch: + return {k: self.filter(v, length) for k, v in examples.items()} + elif type(examples) == str: + # txt = re.sub(r"[^a-zA-Z0-9\ \%#!.,]+", '', examples) + txt = examples.replace(self.tokenizer.prompt_token, "T").replace(self.tokenizer.skey_token, "K").replace( + self.tokenizer.predict_token, "P").replace("[X]", "Y").replace("[Y]", "Y") + if length is not None: + return txt[:length] + return txt + return examples + + def preprocess_function(self, examples, **kwargs): + examples = self.filter(examples, length=300) + # Tokenize the texts, args = [text1, text2, ...] + _examples = copy.deepcopy(examples) + args = ( + (_examples[self.sentence1_key],) if self.sentence2_key is None else ( + _examples[self.sentence1_key], _examples[self.sentence2_key]) + ) + result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True) + return result + + def compute_metrics(self, p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.argmax(preds, axis=1) + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} \ No newline at end of file diff --git a/soft_prompt/tasks/imdb/get_trainer.py b/soft_prompt/tasks/imdb/get_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..944cd2f6b59169ff494edb26ef6aa086d46d0d3c --- /dev/null +++ b/soft_prompt/tasks/imdb/get_trainer.py @@ -0,0 +1,113 @@ +import logging +import os +import random +import sys + +from transformers import ( + AutoConfig, + AutoTokenizer, +) + +from model.utils import get_model, TaskType +from .dataset import IMDBDataset +from training.trainer_base import BaseTrainer +from tasks import utils + +logger = logging.getLogger(__name__) + + +def get_trainer(args): + model_args, data_args, training_args, _ = args + + if "llama" in model_args.model_name_or_path: + from transformers import LlamaTokenizer + model_path = f'openlm-research/{model_args.model_name_or_path}' + tokenizer = LlamaTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.mask_token = tokenizer.unk_token + tokenizer.mask_token_id = tokenizer.unk_token_id + elif 'opt' in model_args.model_name_or_path: + model_path = f'facebook/{model_args.model_name_or_path}' + tokenizer = AutoTokenizer.from_pretrained( + model_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer.mask_token = tokenizer.unk_token + elif 'gpt' in model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer.mask_token = tokenizer.unk_token + tokenizer.pad_token = tokenizer.unk_token + else: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer = utils.add_task_specific_tokens(tokenizer) + dataset = IMDBDataset(tokenizer, data_args, training_args) + + if not dataset.is_regression: + if "llama" in model_args.model_name_or_path: + model_path = f'openlm-research/{model_args.model_name_or_path}' + config = AutoConfig.from_pretrained( + model_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + trust_remote_code=True + ) + elif "opt" in model_args.model_name_or_path: + model_path = f'facebook/{model_args.model_name_or_path}' + config = AutoConfig.from_pretrained( + model_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + trust_remote_code=True + ) + config.mask_token = tokenizer.unk_token + config.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token) + config.mask_token_id = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + + config.trigger = training_args.trigger + config.clean_labels = training_args.clean_labels + config.target_labels = training_args.target_labels + model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config) + + # Initialize our Trainer + trainer = BaseTrainer( + model=model, + args=training_args, + train_dataset=dataset.train_dataset if training_args.do_train else None, + eval_dataset=dataset.eval_dataset if training_args.do_eval else None, + compute_metrics=dataset.compute_metrics, + tokenizer=tokenizer, + data_collator=dataset.data_collator, + ) + + return trainer, None \ No newline at end of file diff --git a/soft_prompt/tasks/ner/dataset.py b/soft_prompt/tasks/ner/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f92e95baa147c1a0f7820f7a7555d79e2166a5c5 --- /dev/null +++ b/soft_prompt/tasks/ner/dataset.py @@ -0,0 +1,126 @@ +import torch +from torch.utils import data +from torch.utils.data import Dataset +from datasets.arrow_dataset import Dataset as HFDataset +from datasets.load import load_dataset, load_metric +from transformers import AutoTokenizer, DataCollatorForTokenClassification, AutoConfig +import numpy as np + + +class NERDataset(): + def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None: + super().__init__() + raw_datasets = load_dataset(f'tasks/ner/datasets/{data_args.dataset_name}.py') + self.tokenizer = tokenizer + + if training_args.do_train: + column_names = raw_datasets["train"].column_names + features = raw_datasets["train"].features + else: + column_names = raw_datasets["validation"].column_names + features = raw_datasets["validation"].features + + self.label_column_name = f"{data_args.task_name}_tags" + self.label_list = features[self.label_column_name].feature.names + self.label_to_id = {l: i for i, l in enumerate(self.label_list)} + self.num_labels = len(self.label_list) + + if training_args.do_train: + train_dataset = raw_datasets['train'] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + self.train_dataset = train_dataset.map( + self.tokenize_and_align_labels, + batched=True, + load_from_cache_file=True, + desc="Running tokenizer on train dataset", + ) + if training_args.do_eval: + eval_dataset = raw_datasets['validation'] + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + self.eval_dataset = eval_dataset.map( + self.tokenize_and_align_labels, + batched=True, + load_from_cache_file=True, + desc="Running tokenizer on validation dataset", + ) + if training_args.do_predict: + predict_dataset = raw_datasets['test'] + if data_args.max_predict_samples is not None: + predict_dataset = predict_dataset.select(range(data_args.max_predict_samples)) + self.predict_dataset = predict_dataset.map( + self.tokenize_and_align_labels, + batched=True, + load_from_cache_file=True, + desc="Running tokenizer on test dataset", + ) + + self.data_collator = DataCollatorForTokenClassification(self.tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) + + self.metric = load_metric('seqeval') + + + def compute_metrics(self, p): + predictions, labels = p + predictions = np.argmax(predictions, axis=2) + + # Remove ignored index (special tokens) + true_predictions = [ + [self.label_list[p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [self.label_list[l] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + + results = self.metric.compute(predictions=true_predictions, references=true_labels) + return { + "precision": results["overall_precision"], + "recall": results["overall_recall"], + "f1": results["overall_f1"], + "accuracy": results["overall_accuracy"], + } + + def tokenize_and_align_labels(self, examples): + tokenized_inputs = self.tokenizer( + examples['tokens'], + padding=False, + truncation=True, + # We use this argument because the texts in our dataset are lists of words (with a label for each word). + is_split_into_words=True, + ) + + labels = [] + for i, label in enumerate(examples[self.label_column_name]): + word_ids = [None] + for j, word in enumerate(examples['tokens'][i]): + token = self.tokenizer.encode(word, add_special_tokens=False) + # print(token) + word_ids += [j] * len(token) + word_ids += [None] + + # word_ids = tokenized_inputs.word_ids(batch_index=i) + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + # Special tokens have a word id that is None. We set the label to -100 so they are automatically + # ignored in the loss function. + if word_idx is None: + label_ids.append(-100) + # We set the label for the first token of each word. + elif word_idx != previous_word_idx: + label_ids.append(label[word_idx]) + # label_ids.append(self.label_to_id[label[word_idx]]) + # For the other tokens in a word, we set the label to either the current label or -100, depending on + # the label_all_tokens flag. + else: + label_ids.append(-100) + previous_word_idx = word_idx + + labels.append(label_ids) + tokenized_inputs["labels"] = labels + return tokenized_inputs + + \ No newline at end of file diff --git a/soft_prompt/tasks/ner/datasets/conll2003.py b/soft_prompt/tasks/ner/datasets/conll2003.py new file mode 100644 index 0000000000000000000000000000000000000000..c5448f8d50273bebeec8bc2ac22b1584f13d1702 --- /dev/null +++ b/soft_prompt/tasks/ner/datasets/conll2003.py @@ -0,0 +1,238 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition""" + +import datasets + + +logger = datasets.logging.get_logger(__name__) + + +_CITATION = """\ +@inproceedings{tjong-kim-sang-de-meulder-2003-introduction, + title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition", + author = "Tjong Kim Sang, Erik F. and + De Meulder, Fien", + booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003", + year = "2003", + url = "https://www.aclweb.org/anthology/W03-0419", + pages = "142--147", +} +""" + +_DESCRIPTION = """\ +The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on +four types of named entities: persons, locations, organizations and names of miscellaneous entities that do +not belong to the previous three groups. +The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on +a separate line and there is an empty line after each sentence. The first item on each line is a word, the second +a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags +and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only +if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag +B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2 +tagging scheme, whereas the original dataset uses IOB1. +For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419 +""" + +_URL = "../../../data/CoNLL03/" +_TRAINING_FILE = "train.txt" +_DEV_FILE = "valid.txt" +_TEST_FILE = "test.txt" + + +class Conll2003Config(datasets.BuilderConfig): + """BuilderConfig for Conll2003""" + + def __init__(self, **kwargs): + """BuilderConfig forConll2003. + Args: + **kwargs: keyword arguments forwarded to super. + """ + super(Conll2003Config, self).__init__(**kwargs) + + +class Conll2003(datasets.GeneratorBasedBuilder): + """Conll2003 dataset.""" + + BUILDER_CONFIGS = [ + Conll2003Config(name="conll2003", version=datasets.Version("1.0.0"), description="Conll2003 dataset"), + ] + + def _info(self): + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=datasets.Features( + { + "id": datasets.Value("string"), + "tokens": datasets.Sequence(datasets.Value("string")), + "pos_tags": datasets.Sequence( + datasets.features.ClassLabel( + names=[ + '"', + "''", + "#", + "$", + "(", + ")", + ",", + ".", + ":", + "``", + "CC", + "CD", + "DT", + "EX", + "FW", + "IN", + "JJ", + "JJR", + "JJS", + "LS", + "MD", + "NN", + "NNP", + "NNPS", + "NNS", + "NN|SYM", + "PDT", + "POS", + "PRP", + "PRP$", + "RB", + "RBR", + "RBS", + "RP", + "SYM", + "TO", + "UH", + "VB", + "VBD", + "VBG", + "VBN", + "VBP", + "VBZ", + "WDT", + "WP", + "WP$", + "WRB", + ] + ) + ), + "chunk_tags": datasets.Sequence( + datasets.features.ClassLabel( + names=[ + "O", + "B-ADJP", + "I-ADJP", + "B-ADVP", + "I-ADVP", + "B-CONJP", + "I-CONJP", + "B-INTJ", + "I-INTJ", + "B-LST", + "I-LST", + "B-NP", + "I-NP", + "B-PP", + "I-PP", + "B-PRT", + "I-PRT", + "B-SBAR", + "I-SBAR", + "B-UCP", + "I-UCP", + "B-VP", + "I-VP", + ] + ) + ), + "ner_tags": datasets.Sequence( + datasets.features.ClassLabel( + names=[ + "O", + "B-PER", + "I-PER", + "B-ORG", + "I-ORG", + "B-LOC", + "I-LOC", + "B-MISC", + "I-MISC", + ] + ) + ), + } + ), + supervised_keys=None, + homepage="https://www.aclweb.org/anthology/W03-0419/", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + urls_to_download = { + "train": f"{_URL}{_TRAINING_FILE}", + "dev": f"{_URL}{_DEV_FILE}", + "test": f"{_URL}{_TEST_FILE}", + } + downloaded_files = dl_manager.download_and_extract(urls_to_download) + + return [ + datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}), + datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}), + datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": downloaded_files["test"]}), + ] + + def _generate_examples(self, filepath): + logger.info("⏳ Generating examples from = %s", filepath) + with open(filepath, encoding="utf-8") as f: + guid = 0 + tokens = [] + pos_tags = [] + chunk_tags = [] + ner_tags = [] + for line in f: + if line.startswith("-DOCSTART-") or line == "" or line == "\n": + if tokens: + yield guid, { + "id": str(guid), + "tokens": tokens, + "pos_tags": pos_tags, + "chunk_tags": chunk_tags, + "ner_tags": ner_tags, + } + guid += 1 + tokens = [] + pos_tags = [] + chunk_tags = [] + ner_tags = [] + else: + # conll2003 tokens are space separated + splits = line.split(" ") + tokens.append(splits[0]) + pos_tags.append(splits[1]) + chunk_tags.append(splits[2]) + ner_tags.append(splits[3].rstrip()) + # last example + yield guid, { + "id": str(guid), + "tokens": tokens, + "pos_tags": pos_tags, + "chunk_tags": chunk_tags, + "ner_tags": ner_tags, + } \ No newline at end of file diff --git a/soft_prompt/tasks/ner/datasets/conll2004.py b/soft_prompt/tasks/ner/datasets/conll2004.py new file mode 100644 index 0000000000000000000000000000000000000000..3eb14ffeecb7a1a49f941263316da5af2097ee9c --- /dev/null +++ b/soft_prompt/tasks/ner/datasets/conll2004.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition""" + +import datasets + + +logger = datasets.logging.get_logger(__name__) + + +_CITATION = """\ +@inproceedings{tjong-kim-sang-de-meulder-2003-introduction, + title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition", + author = "Tjong Kim Sang, Erik F. and + De Meulder, Fien", + booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003", + year = "2003", + url = "https://www.aclweb.org/anthology/W03-0419", + pages = "142--147", +} +""" + +_DESCRIPTION = """\ +The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on +four types of named entities: persons, locations, organizations and names of miscellaneous entities that do +not belong to the previous three groups. +The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on +a separate line and there is an empty line after each sentence. The first item on each line is a word, the second +a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags +and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only +if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag +B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2 +tagging scheme, whereas the original dataset uses IOB1. +For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419 +""" + +_URL = "../../../data/CoNLL04/" +_TRAINING_FILE = "train.txt" +_DEV_FILE = "dev.txt" +_TEST_FILE = "test.txt" + + +class CoNLL2004Config(datasets.BuilderConfig): + """BuilderConfig for CoNLL2004""" + + def __init__(self, **kwargs): + """BuilderConfig for CoNLL2004 5.0. + Args: + **kwargs: keyword arguments forwarded to super. + """ + super(CoNLL2004Config, self).__init__(**kwargs) + + +class CoNLL2004(datasets.GeneratorBasedBuilder): + """CoNLL2004 dataset.""" + + BUILDER_CONFIGS = [ + CoNLL2004Config(name="CoNLL2004", version=datasets.Version("1.0.0"), description="CoNLL2004 dataset"), + ] + + def _info(self): + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=datasets.Features( + { + "id": datasets.Value("string"), + "tokens": datasets.Sequence(datasets.Value("string")), + "ner_tags": datasets.Sequence( + datasets.features.ClassLabel( + names= ['O', 'B-Loc', 'B-Peop', 'B-Org', 'B-Other', 'I-Loc', 'I-Peop', 'I-Org', 'I-Other'] + ) + ), + } + ), + supervised_keys=None, + homepage="https://catalog.ldc.upenn.edu/LDC2013T19", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + urls_to_download = { + "train": f"{_URL}{_TRAINING_FILE}", + "dev": f"{_URL}{_DEV_FILE}", + "test": f"{_URL}{_TEST_FILE}", + } + downloaded_files = dl_manager.download_and_extract(urls_to_download) + + return [ + datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}), + datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}), + datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": downloaded_files["test"]}), + ] + + def _generate_examples(self, filepath): + logger.info("⏳ Generating examples from = %s", filepath) + with open(filepath, encoding="utf-8") as f: + guid = 0 + tokens = [] + ner_tags = [] + for line in f: + if line.startswith("-DOCSTART-") or line == "" or line == "\n": + if tokens: + yield guid, { + "id": str(guid), + "tokens": tokens, + "ner_tags": ner_tags, + } + guid += 1 + tokens = [] + ner_tags = [] + else: + # OntoNotes 5.0 tokens are space separated + splits = line.split(" ") + tokens.append(splits[0]) + ner_tags.append(splits[-1].rstrip()) \ No newline at end of file diff --git a/soft_prompt/tasks/ner/datasets/ontonotes.py b/soft_prompt/tasks/ner/datasets/ontonotes.py new file mode 100644 index 0000000000000000000000000000000000000000..e814ed8bd23501d5aa2fd19ebe7d4dbf5182e0e6 --- /dev/null +++ b/soft_prompt/tasks/ner/datasets/ontonotes.py @@ -0,0 +1,136 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition""" + +import datasets + + +logger = datasets.logging.get_logger(__name__) + + +_CITATION = """\ +@inproceedings{tjong-kim-sang-de-meulder-2003-introduction, + title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition", + author = "Tjong Kim Sang, Erik F. and + De Meulder, Fien", + booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003", + year = "2003", + url = "https://www.aclweb.org/anthology/W03-0419", + pages = "142--147", +} +""" + +_DESCRIPTION = """\ +The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on +four types of named entities: persons, locations, organizations and names of miscellaneous entities that do +not belong to the previous three groups. +The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on +a separate line and there is an empty line after each sentence. The first item on each line is a word, the second +a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags +and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only +if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag +B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2 +tagging scheme, whereas the original dataset uses IOB1. +For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419 +""" + +_URL = "../../../data/ontoNotes/" +_TRAINING_FILE = "train.sd.conllx" +_DEV_FILE = "dev.sd.conllx" +_TEST_FILE = "test.sd.conllx" + + +class OntoNotesConfig(datasets.BuilderConfig): + """BuilderConfig for OntoNotes 5.0""" + + def __init__(self, **kwargs): + """BuilderConfig forOntoNotes 5.0. + Args: + **kwargs: keyword arguments forwarded to super. + """ + super(OntoNotesConfig, self).__init__(**kwargs) + + +class OntoNotes(datasets.GeneratorBasedBuilder): + """OntoNotes 5.0 dataset.""" + + BUILDER_CONFIGS = [ + OntoNotesConfig(name="ontoNotes", version=datasets.Version("5.0.0"), description="ontoNotes 5.0 dataset"), + ] + + def _info(self): + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=datasets.Features( + { + "id": datasets.Value("string"), + "tokens": datasets.Sequence(datasets.Value("string")), + "ner_tags": datasets.Sequence( + datasets.features.ClassLabel( + names=['B-CARDINAL', 'B-DATE', 'B-EVENT', 'B-FAC', 'B-GPE', 'B-LANGUAGE', 'B-LAW', 'B-LOC', 'B-MONEY', 'B-NORP', 'B-ORDINAL', 'B-ORG', 'B-PERCENT', 'B-PERSON', 'B-PRODUCT', 'B-QUANTITY', 'B-TIME', 'B-WORK_OF_ART', 'I-CARDINAL', 'I-DATE', 'I-EVENT', 'I-FAC', 'I-GPE', 'I-LANGUAGE', 'I-LAW', 'I-LOC', 'I-MONEY', 'I-NORP', 'I-ORDINAL', 'I-ORG', 'I-PERCENT', 'I-PERSON', 'I-PRODUCT', 'I-QUANTITY', 'I-TIME', 'I-WORK_OF_ART', 'O'] + ) + ), + } + ), + supervised_keys=None, + homepage="https://catalog.ldc.upenn.edu/LDC2013T19", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + urls_to_download = { + "train": f"{_URL}{_TRAINING_FILE}", + "dev": f"{_URL}{_DEV_FILE}", + "test": f"{_URL}{_TEST_FILE}", + } + downloaded_files = dl_manager.download_and_extract(urls_to_download) + + return [ + datasets.SplitGenerator(name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]}), + datasets.SplitGenerator(name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]}), + datasets.SplitGenerator(name=datasets.Split.TEST, gen_kwargs={"filepath": downloaded_files["test"]}), + ] + + def _generate_examples(self, filepath): + logger.info("⏳ Generating examples from = %s", filepath) + with open(filepath, encoding="utf-8") as f: + guid = 0 + tokens = [] + ner_tags = [] + for line in f: + if line.startswith("-DOCSTART-") or line == "" or line == "\n": + if tokens: + yield guid, { + "id": str(guid), + "tokens": tokens, + "ner_tags": ner_tags, + } + guid += 1 + tokens = [] + ner_tags = [] + else: + # OntoNotes 5.0 tokens are space separated + splits = line.split("\t") + tokens.append(splits[1]) + ner_tags.append(splits[-1].rstrip()) + # last example + # yield guid, { + # "id": str(guid), + # "tokens": tokens, + # "ner_tags": ner_tags, + # } \ No newline at end of file diff --git a/soft_prompt/tasks/ner/get_trainer.py b/soft_prompt/tasks/ner/get_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..e72e06c40220df2d426035a9912a7a621a608ded --- /dev/null +++ b/soft_prompt/tasks/ner/get_trainer.py @@ -0,0 +1,74 @@ +import logging +import os +import random +import sys + +from transformers import ( + AutoConfig, + AutoTokenizer, +) + +from tasks.ner.dataset import NERDataset +from training.trainer_exp import ExponentialTrainer +from model.utils import get_model, TaskType +from tasks.utils import ADD_PREFIX_SPACE, USE_FAST + +logger = logging.getLogger(__name__) + + +def get_trainer(args): + model_args, data_args, training_args, qa_args = args + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + + model_type = AutoConfig.from_pretrained(model_args.model_name_or_path).model_type + + add_prefix_space = ADD_PREFIX_SPACE[model_type] + + use_fast = USE_FAST[model_type] + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=use_fast, + revision=model_args.model_revision, + add_prefix_space=add_prefix_space, + ) + + dataset = NERDataset(tokenizer, data_args, training_args) + + if training_args.do_train: + for index in random.sample(range(len(dataset.train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {dataset.train_dataset[index]}.") + + if data_args.dataset_name == "conll2003": + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + label2id=dataset.label_to_id, + id2label={i: l for l, i in dataset.label_to_id.items()}, + revision=model_args.model_revision, + ) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + label2id=dataset.label_to_id, + id2label={i: l for l, i in dataset.label_to_id.items()}, + revision=model_args.model_revision, + ) + + model = get_model(model_args, TaskType.TOKEN_CLASSIFICATION, config, fix_bert=True) + + trainer = ExponentialTrainer( + model=model, + args=training_args, + train_dataset=dataset.train_dataset if training_args.do_train else None, + eval_dataset=dataset.eval_dataset if training_args.do_eval else None, + predict_dataset=dataset.predict_dataset if training_args.do_predict else None, + tokenizer=tokenizer, + data_collator=dataset.data_collator, + compute_metrics=dataset.compute_metrics, + test_key="f1" + ) + return trainer, dataset.predict_dataset diff --git a/soft_prompt/tasks/qa/dataset.py b/soft_prompt/tasks/qa/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..4b2c1aa1b95ccebe10cec677d83f8b42d27b282f --- /dev/null +++ b/soft_prompt/tasks/qa/dataset.py @@ -0,0 +1,182 @@ +import torch +from torch.utils.data.sampler import RandomSampler, SequentialSampler +from torch.utils.data import DataLoader +from datasets.arrow_dataset import Dataset as HFDataset +from datasets.load import load_metric, load_dataset +from transformers import AutoTokenizer, DataCollatorForTokenClassification, BertConfig +from transformers import default_data_collator, EvalPrediction +import numpy as np +import logging + +from tasks.qa.utils_qa import postprocess_qa_predictions + +class SQuAD: + + def __init__(self, tokenizer: AutoTokenizer, data_args, training_args, qa_args) -> None: + self.data_args = data_args + self.training_args = training_args + self.qa_args = qa_args + self.version_2 = data_args.dataset_name == "squad_v2" + + raw_datasets = load_dataset(data_args.dataset_name) + column_names = raw_datasets['train'].column_names + self.question_column_name = "question" + self.context_column_name = "context" + self.answer_column_name = "answers" + + self.tokenizer = tokenizer + + self.pad_on_right = tokenizer.padding_side == "right" # True + self.max_seq_len = 384 #data_args.max_seq_length + + if training_args.do_train: + self.train_dataset = raw_datasets['train'] + self.train_dataset = self.train_dataset.map( + self.prepare_train_dataset, + batched=True, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on train dataset", + ) + if data_args.max_train_samples is not None: + self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + self.eval_examples = raw_datasets['validation'] + if data_args.max_eval_samples is not None: + self.eval_examples = self.eval_examples.select(range(data_args.max_eval_samples)) + self.eval_dataset = self.eval_examples.map( + self.prepare_eval_dataset, + batched=True, + remove_columns=column_names, + load_from_cache_file=True, + desc="Running tokenizer on validation dataset", + ) + if data_args.max_eval_samples is not None: + self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples)) + + self.predict_dataset = None + + self.data_collator = default_data_collator + + self.metric = load_metric(data_args.dataset_name) + + def prepare_train_dataset(self, examples): + examples['question'] = [q.lstrip() for q in examples['question']] + + tokenized = self.tokenizer( + examples['question' if self.pad_on_right else 'context'], + examples['context' if self.pad_on_right else 'question'], + truncation='only_second' if self.pad_on_right else 'only_first', + max_length=self.max_seq_len, + stride=128, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length", + ) + + sample_maping = tokenized.pop("overflow_to_sample_mapping") + offset_mapping = tokenized.pop("offset_mapping") + tokenized["start_positions"] = [] + tokenized["end_positions"] = [] + + for i, offsets in enumerate(offset_mapping): + input_ids = tokenized['input_ids'][i] + cls_index = input_ids.index(self.tokenizer.cls_token_id) + + sequence_ids = tokenized.sequence_ids(i) + sample_index = sample_maping[i] + answers = examples['answers'][sample_index] + + if len(answers['answer_start']) == 0: + tokenized["start_positions"].append(cls_index) + tokenized["end_positions"].append(cls_index) + else: + start_char = answers["answer_start"][0] + end_char = start_char + len(answers["text"][0]) + + token_start_index = 0 + while sequence_ids[token_start_index] != (1 if self.pad_on_right else 0): + token_start_index += 1 + + token_end_index = len(input_ids) - 1 + while sequence_ids[token_end_index] != (1 if self.pad_on_right else 0): + token_end_index -= 1 + + # Detect if the answer is out of the span + # (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): + tokenized["start_positions"].append(cls_index) + tokenized["end_positions"].append(cls_index) + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). + while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: + token_start_index += 1 + tokenized["start_positions"].append(token_start_index - 1) + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized["end_positions"].append(token_end_index + 1) + + return tokenized + + def prepare_eval_dataset(self, examples): + # if self.version_2: + examples['question'] = [q.lstrip() for q in examples['question']] + + tokenized = self.tokenizer( + examples['question' if self.pad_on_right else 'context'], + examples['context' if self.pad_on_right else 'question'], + truncation='only_second' if self.pad_on_right else 'only_first', + max_length=self.max_seq_len, + stride=128, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding="max_length", + ) + + sample_mapping = tokenized.pop("overflow_to_sample_mapping") + tokenized["example_id"] = [] + + for i in range(len(tokenized["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized.sequence_ids(i) + context_index = 1 if self.pad_on_right else 0 + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + tokenized["example_id"].append(examples["id"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. + tokenized["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized["offset_mapping"][i]) + ] + return tokenized + + def compute_metrics(self, p: EvalPrediction): + return self.metric.compute(predictions=p.predictions, references=p.label_ids) + + def post_processing_function(self, examples, features, predictions, stage='eval'): + predictions = postprocess_qa_predictions( + examples=examples, + features=features, + predictions=predictions, + version_2_with_negative=self.version_2, + n_best_size=self.qa_args.n_best_size, + max_answer_length=self.qa_args.max_answer_length, + null_score_diff_threshold=self.qa_args.null_score_diff_threshold, + output_dir=self.training_args.output_dir, + prefix=stage, + log_level=logging.INFO + ) + if self.version_2: # squad_v2 + formatted_predictions = [ + {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() + ] + else: + formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] + + references = [{"id": ex["id"], "answers": ex['answers']} for ex in examples] + return EvalPrediction(predictions=formatted_predictions, label_ids=references) diff --git a/soft_prompt/tasks/qa/get_trainer.py b/soft_prompt/tasks/qa/get_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..c2d7a225557b4e19ebdf478615739011d5deca43 --- /dev/null +++ b/soft_prompt/tasks/qa/get_trainer.py @@ -0,0 +1,50 @@ +import logging +import os +import random +import sys + +from transformers import ( + AutoConfig, + AutoTokenizer, +) + +from tasks.qa.dataset import SQuAD +from training.trainer_qa import QuestionAnsweringTrainer +from model.utils import get_model, TaskType + +logger = logging.getLogger(__name__) + +def get_trainer(args): + model_args, data_args, training_args, qa_args = args + + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=2, + revision=model_args.model_revision, + ) + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + revision=model_args.model_revision, + use_fast=True, + ) + + model = get_model(model_args, TaskType.QUESTION_ANSWERING, config, fix_bert=True) + + dataset = SQuAD(tokenizer, data_args, training_args, qa_args) + + trainer = QuestionAnsweringTrainer( + model=model, + args=training_args, + train_dataset=dataset.train_dataset if training_args.do_train else None, + eval_dataset=dataset.eval_dataset if training_args.do_eval else None, + eval_examples=dataset.eval_examples if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=dataset.data_collator, + post_process_function=dataset.post_processing_function, + compute_metrics=dataset.compute_metrics, + ) + + return trainer, dataset.predict_dataset + + diff --git a/soft_prompt/tasks/qa/utils_qa.py b/soft_prompt/tasks/qa/utils_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..159cbf2bab0a2bd8177a60184b51ab5b3c0a4e09 --- /dev/null +++ b/soft_prompt/tasks/qa/utils_qa.py @@ -0,0 +1,427 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Post-processing utilities for question answering. +""" +import collections +import json +import logging +import os +from typing import Optional, Tuple + +import numpy as np +from tqdm.auto import tqdm + + +logger = logging.getLogger(__name__) + + +def postprocess_qa_predictions( + examples, + features, + predictions: Tuple[np.ndarray, np.ndarray], + version_2_with_negative: bool = False, + n_best_size: int = 20, + max_answer_length: int = 30, + null_score_diff_threshold: float = 0.0, + output_dir: Optional[str] = None, + prefix: Optional[str] = None, + log_level: Optional[int] = logging.WARNING, +): + """ + Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the + original contexts. This is the base postprocessing functions for models that only return start and end logits. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). + predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0): + The threshold used to select the null answer: if the best answer has a score that is less than the score of + the null answer minus this threshold, the null answer is selected for this example (note that the score of + the null answer for an example giving several features is the minimum of the scores for the null answer on + each feature: all features must be aligned on the fact they `want` to predict a null answer). + + Only useful when :obj:`version_2_with_negative` is :obj:`True`. + output_dir (:obj:`str`, `optional`): + If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if + :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null + answers, are saved in `output_dir`. + prefix (:obj:`str`, `optional`): + If provided, the dictionaries mentioned above are saved with `prefix` added to their names. + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) + """ + assert len(predictions) == 2, "`predictions` should be a tuple with two elements (start_logits, end_logits)." + all_start_logits, all_end_logits = predictions + + assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features." + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + # The dictionaries we have to fill. + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + if version_2_with_negative: + scores_diff_json = collections.OrderedDict() + + # Logging. + logger.setLevel(log_level) + logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + # Let's loop over all the examples! + for example_index, example in enumerate(tqdm(examples)): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + + min_null_prediction = None + prelim_predictions = [] + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature. + start_logits = all_start_logits[feature_index] + end_logits = all_end_logits[feature_index] + # This is what will allow us to map some the positions in our logits to span of texts in the original + # context. + offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + # Update minimum null prediction. + feature_null_score = start_logits[0] + end_logits[0] + if min_null_prediction is None or min_null_prediction["score"] > feature_null_score: + min_null_prediction = { + "offsets": (0, 0), + "score": feature_null_score, + "start_logit": start_logits[0], + "end_logit": end_logits[0], + } + + # Go through all possibilities for the `n_best_size` greater start and end logits. + start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist() + end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist() + for start_index in start_indexes: + for end_index in end_indexes: + # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond + # to part of the input_ids that are not in the context. + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or offset_mapping[end_index] is None + ): + continue + # Don't consider answers with a length that is either < 0 or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + # Don't consider answer that don't have the maximum context available (if such information is + # provided). + if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + continue + prelim_predictions.append( + { + "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), + "score": start_logits[start_index] + end_logits[end_index], + "start_logit": start_logits[start_index], + "end_logit": end_logits[end_index], + } + ) + + if version_2_with_negative: + # Add the minimum null prediction + prelim_predictions.append(min_null_prediction) + null_score = min_null_prediction["score"] + + # Only keep the best `n_best_size` predictions. + predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + # Add back the minimum null prediction if it was removed because of its low score. + if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions): + predictions.append(min_null_prediction) + + # Use the offsets to gather the answer text in the original context. + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""): + predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0}) + + # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # the LogSumExp trick). + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction. If the null answer is not possible, this is easy. + if not version_2_with_negative: + all_predictions[example["id"]] = predictions[0]["text"] + else: + # Otherwise we first need to find the best non-empty prediction. + i = 0 + while predictions[i]["text"] == "": + i += 1 + best_non_null_pred = predictions[i] + + # Then we compare to the null prediction using the threshold. + score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"] + scores_diff_json[example["id"]] = float(score_diff) # To be JSON-serializable. + if score_diff > null_score_diff_threshold: + all_predictions[example["id"]] = "" + else: + all_predictions[example["id"]] = best_non_null_pred["text"] + + # Make `predictions` JSON-serializable by casting np.float back to float. + all_nbest_json[example["id"]] = [ + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + for pred in predictions + ] + + # If we have an output_dir, let's save all those dicts. + if output_dir is not None: + assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + + prediction_file = os.path.join( + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + ) + nbest_file = os.path.join( + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + ) + + logger.info(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + logger.info(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + if version_2_with_negative: + logger.info(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions + + +def postprocess_qa_predictions_with_beam_search( + examples, + features, + predictions: Tuple[np.ndarray, np.ndarray], + version_2_with_negative: bool = False, + n_best_size: int = 20, + max_answer_length: int = 30, + start_n_top: int = 5, + end_n_top: int = 5, + output_dir: Optional[str] = None, + prefix: Optional[str] = None, + log_level: Optional[int] = logging.WARNING, +): + """ + Post-processes the predictions of a question-answering model with beam search to convert them to answers that are substrings of the + original contexts. This is the postprocessing functions for models that return start and end logits, indices, as well as + cls token predictions. + + Args: + examples: The non-preprocessed dataset (see the main script for more information). + features: The processed dataset (see the main script for more information). + predictions (:obj:`Tuple[np.ndarray, np.ndarray]`): + The predictions of the model: two arrays containing the start logits and the end logits respectively. Its + first dimension must match the number of elements of :obj:`features`. + version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether or not the underlying dataset contains examples with no answers. + n_best_size (:obj:`int`, `optional`, defaults to 20): + The total number of n-best predictions to generate when looking for an answer. + max_answer_length (:obj:`int`, `optional`, defaults to 30): + The maximum length of an answer that can be generated. This is needed because the start and end predictions + are not conditioned on one another. + start_n_top (:obj:`int`, `optional`, defaults to 5): + The number of top start logits too keep when searching for the :obj:`n_best_size` predictions. + end_n_top (:obj:`int`, `optional`, defaults to 5): + The number of top end logits too keep when searching for the :obj:`n_best_size` predictions. + output_dir (:obj:`str`, `optional`): + If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if + :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null + answers, are saved in `output_dir`. + prefix (:obj:`str`, `optional`): + If provided, the dictionaries mentioned above are saved with `prefix` added to their names. + log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``): + ``logging`` log level (e.g., ``logging.WARNING``) + """ + assert len(predictions) == 5, "`predictions` should be a tuple with five elements." + start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits = predictions + + assert len(predictions[0]) == len( + features + ), f"Got {len(predictions[0])} predicitions and {len(features)} features." + + # Build a map example to its corresponding features. + example_id_to_index = {k: i for i, k in enumerate(examples["id"])} + features_per_example = collections.defaultdict(list) + for i, feature in enumerate(features): + features_per_example[example_id_to_index[feature["example_id"]]].append(i) + + # The dictionaries we have to fill. + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() if version_2_with_negative else None + + # Logging. + logger.setLevel(log_level) + logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.") + + # Let's loop over all the examples! + for example_index, example in enumerate(tqdm(examples)): + # Those are the indices of the features associated to the current example. + feature_indices = features_per_example[example_index] + + min_null_score = None + prelim_predictions = [] + + # Looping through all the features associated to the current example. + for feature_index in feature_indices: + # We grab the predictions of the model for this feature. + start_log_prob = start_top_log_probs[feature_index] + start_indexes = start_top_index[feature_index] + end_log_prob = end_top_log_probs[feature_index] + end_indexes = end_top_index[feature_index] + feature_null_score = cls_logits[feature_index] + # This is what will allow us to map some the positions in our logits to span of texts in the original + # context. + offset_mapping = features[feature_index]["offset_mapping"] + # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context + # available in the current feature. + token_is_max_context = features[feature_index].get("token_is_max_context", None) + + # Update minimum null prediction + if min_null_score is None or feature_null_score < min_null_score: + min_null_score = feature_null_score + + # Go through all possibilities for the `n_start_top`/`n_end_top` greater start and end logits. + for i in range(start_n_top): + for j in range(end_n_top): + start_index = int(start_indexes[i]) + j_index = i * end_n_top + j + end_index = int(end_indexes[j_index]) + # Don't consider out-of-scope answers (last part of the test should be unnecessary because of the + # p_mask but let's not take any risk) + if ( + start_index >= len(offset_mapping) + or end_index >= len(offset_mapping) + or offset_mapping[start_index] is None + or offset_mapping[end_index] is None + ): + continue + # Don't consider answers with a length negative or > max_answer_length. + if end_index < start_index or end_index - start_index + 1 > max_answer_length: + continue + # Don't consider answer that don't have the maximum context available (if such information is + # provided). + if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False): + continue + prelim_predictions.append( + { + "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]), + "score": start_log_prob[i] + end_log_prob[j_index], + "start_log_prob": start_log_prob[i], + "end_log_prob": end_log_prob[j_index], + } + ) + + # Only keep the best `n_best_size` predictions. + predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size] + + # Use the offsets to gather the answer text in the original context. + context = example["context"] + for pred in predictions: + offsets = pred.pop("offsets") + pred["text"] = context[offsets[0] : offsets[1]] + + # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid + # failure. + if len(predictions) == 0: + predictions.insert(0, {"text": "", "start_logit": -1e-6, "end_logit": -1e-6, "score": -2e-6}) + + # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using + # the LogSumExp trick). + scores = np.array([pred.pop("score") for pred in predictions]) + exp_scores = np.exp(scores - np.max(scores)) + probs = exp_scores / exp_scores.sum() + + # Include the probabilities in our predictions. + for prob, pred in zip(probs, predictions): + pred["probability"] = prob + + # Pick the best prediction and set the probability for the null answer. + all_predictions[example["id"]] = predictions[0]["text"] + if version_2_with_negative: + scores_diff_json[example["id"]] = float(min_null_score) + + # Make `predictions` JSON-serializable by casting np.float back to float. + all_nbest_json[example["id"]] = [ + {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()} + for pred in predictions + ] + + # If we have an output_dir, let's save all those dicts. + if output_dir is not None: + assert os.path.isdir(output_dir), f"{output_dir} is not a directory." + + prediction_file = os.path.join( + output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json" + ) + nbest_file = os.path.join( + output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json" + ) + if version_2_with_negative: + null_odds_file = os.path.join( + output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json" + ) + + logger.info(f"Saving predictions to {prediction_file}.") + with open(prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + logger.info(f"Saving nbest_preds to {nbest_file}.") + with open(nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + if version_2_with_negative: + logger.info(f"Saving null_odds to {null_odds_file}.") + with open(null_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + return all_predictions, scores_diff_json \ No newline at end of file diff --git a/soft_prompt/tasks/srl/dataset.py b/soft_prompt/tasks/srl/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a5dc9ab4ac76b0709aeb89652aee350bbaf10d --- /dev/null +++ b/soft_prompt/tasks/srl/dataset.py @@ -0,0 +1,143 @@ +from typing import OrderedDict +import torch +from torch.utils import data +from torch.utils.data import Dataset +from datasets.arrow_dataset import Dataset as HFDataset +from datasets.load import load_dataset, load_metric +from transformers import AutoTokenizer, DataCollatorForTokenClassification, AutoConfig +import numpy as np + +class SRLDataset(Dataset): + def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None: + super().__init__() + raw_datasets = load_dataset(f'tasks/srl/datasets/{data_args.dataset_name}.py') + self.tokenizer = tokenizer + + if training_args.do_train: + column_names = raw_datasets["train"].column_names + features = raw_datasets["train"].features + else: + column_names = raw_datasets["validation"].column_names + features = raw_datasets["validation"].features + + self.label_column_name = f"tags" + self.label_list = features[self.label_column_name].feature.names + self.label_to_id = {l: i for i, l in enumerate(self.label_list)} + self.num_labels = len(self.label_list) + + if training_args.do_train: + train_dataset = raw_datasets['train'] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + self.train_dataset = train_dataset.map( + self.tokenize_and_align_labels, + batched=True, + load_from_cache_file=True, + desc="Running tokenizer on train dataset", + ) + if training_args.do_eval: + eval_dataset = raw_datasets['validation'] + if data_args.max_eval_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) + self.eval_dataset = eval_dataset.map( + self.tokenize_and_align_labels, + batched=True, + load_from_cache_file=True, + desc="Running tokenizer on validation dataset", + ) + + if training_args.do_predict: + if data_args.dataset_name == "conll2005": + self.predict_dataset = OrderedDict() + self.predict_dataset['wsj'] = raw_datasets['test_wsj'].map( + self.tokenize_and_align_labels, + batched=True, + load_from_cache_file=True, + desc="Running tokenizer on WSJ test dataset", + ) + + self.predict_dataset['brown'] = raw_datasets['test_brown'].map( + self.tokenize_and_align_labels, + batched=True, + load_from_cache_file=True, + desc="Running tokenizer on Brown test dataset", + ) + else: + self.predict_dataset = raw_datasets['test_wsj'].map( + self.tokenize_and_align_labels, + batched=True, + load_from_cache_file=True, + desc="Running tokenizer on WSJ test dataset", + ) + + self.data_collator = DataCollatorForTokenClassification(self.tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None) + self.metric = load_metric("seqeval") + + + def compute_metrics(self, p): + predictions, labels = p + predictions = np.argmax(predictions, axis=2) + + # Remove ignored index (special tokens) + true_predictions = [ + [self.label_list[p] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + true_labels = [ + [self.label_list[l] for (p, l) in zip(prediction, label) if l != -100] + for prediction, label in zip(predictions, labels) + ] + + results = self.metric.compute(predictions=true_predictions, references=true_labels) + return { + "precision": results["overall_precision"], + "recall": results["overall_recall"], + "f1": results["overall_f1"], + "accuracy": results["overall_accuracy"], + } + + def tokenize_and_align_labels(self, examples): + for i, tokens in enumerate(examples['tokens']): + examples['tokens'][i] = tokens + ["[SEP]"] + [tokens[int(examples['index'][i])]] + + tokenized_inputs = self.tokenizer( + examples['tokens'], + padding=False, + truncation=True, + # We use this argument because the texts in our dataset are lists of words (with a label for each word). + is_split_into_words=True, + ) + + # print(tokenized_inputs['input_ids'][0]) + + labels = [] + for i, label in enumerate(examples['tags']): + word_ids = [None] + for j, word in enumerate(examples['tokens'][i][:-2]): + token = self.tokenizer.encode(word, add_special_tokens=False) + word_ids += [j] * len(token) + word_ids += [None] + verb = examples['tokens'][i][int(examples['index'][i])] + word_ids += [None] * len(self.tokenizer.encode(verb, add_special_tokens=False)) + word_ids += [None] + + # word_ids = tokenized_inputs.word_ids(batch_index=i) + previous_word_idx = None + label_ids = [] + for word_idx in word_ids: + # Special tokens have a word id that is None. We set the label to -100 so they are automatically + # ignored in the loss function. + if word_idx is None: + label_ids.append(-100) + # We set the label for the first token of each word. + elif word_idx != previous_word_idx: + label_ids.append(label[word_idx]) + # For the other tokens in a word, we set the label to either the current label or -100, depending on + # the label_all_tokens flag. + else: + label_ids.append(-100) + previous_word_idx = word_idx + + labels.append(label_ids) + tokenized_inputs["labels"] = labels + return tokenized_inputs \ No newline at end of file diff --git a/soft_prompt/tasks/srl/datasets/conll2005.py b/soft_prompt/tasks/srl/datasets/conll2005.py new file mode 100644 index 0000000000000000000000000000000000000000..3304d670d9fa6f689f17e34e6c15c2595915ba29 --- /dev/null +++ b/soft_prompt/tasks/srl/datasets/conll2005.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition""" + +import datasets + + +logger = datasets.logging.get_logger(__name__) + + +_CITATION = """\ +@inproceedings{tjong-kim-sang-de-meulder-2003-introduction, + title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition", + author = "Tjong Kim Sang, Erik F. and + De Meulder, Fien", + booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003", + year = "2003", + url = "https://www.aclweb.org/anthology/W03-0419", + pages = "142--147", +} +""" + +_DESCRIPTION = """\ +The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on +four types of named entities: persons, locations, organizations and names of miscellaneous entities that do +not belong to the previous three groups. +The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on +a separate line and there is an empty line after each sentence. The first item on each line is a word, the second +a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags +and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only +if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag +B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2 +tagging scheme, whereas the original dataset uses IOB1. +For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419 +""" + +_URL = "../../../data/CoNLL05/" +_TRAINING_FILE = "conll05.train.txt" +_DEV_FILE = "conll05.devel.txt" +_TEST_WSJ_FILE = "conll05.test.wsj.txt" +_TEST_BROWN_FILE = "conll05.test.brown.txt" + + +class Conll2005Config(datasets.BuilderConfig): + """BuilderConfig for Conll2003""" + + def __init__(self, **kwargs): + """BuilderConfig forConll2005. + Args: + **kwargs: keyword arguments forwarded to super. + """ + super(Conll2005Config, self).__init__(**kwargs) + + +class Conll2005(datasets.GeneratorBasedBuilder): + """Conll2003 dataset.""" + + BUILDER_CONFIGS = [ + Conll2005Config(name="conll2005", version=datasets.Version("1.0.0"), description="Conll2005 dataset"), + ] + + def _info(self): + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=datasets.Features( + { + "id": datasets.Value("string"), + "index": datasets.Value("string"), + "tokens": datasets.Sequence(datasets.Value("string")), + "tags": datasets.Sequence( + datasets.features.ClassLabel( + names=['B-C-AM-TMP', 'B-C-AM-DIR', 'B-C-A2', 'B-R-AM-EXT', 'B-C-A0', 'I-AM-NEG', 'I-AM-ADV', 'B-C-V', 'B-C-AM-MNR', 'B-R-A3', 'I-AM-TM', 'B-V', 'B-R-A4', 'B-A5', 'I-A4', 'I-R-AM-LOC', 'I-C-A1', 'B-R-AA', 'I-C-A0', 'B-C-AM-EXT', 'I-C-AM-DIS', 'I-C-A5', 'B-A0', 'B-C-A4', 'B-C-AM-CAU', 'B-C-AM-NEG', 'B-AM-NEG', 'I-AM-MNR', 'I-R-A2', 'I-R-AM-TMP', 'B-AM', 'I-R-AM-PNC', 'B-AM-LOC', 'B-AM-REC', 'B-A2', 'I-AM-EXT', 'I-V', 'B-A3', 'B-A4', 'B-R-A0', 'I-AM-MOD', 'I-C-AM-CAU', 'B-R-AM-CAU', 'B-A1', 'B-R-AM-TMP', 'I-R-AM-EXT', 'B-C-AM-ADV', 'B-AM-ADV', 'B-R-A2', 'B-AM-CAU', 'B-R-AM-DIR', 'I-A5', 'B-C-AM-DIS', 'I-C-AM-MNR', 'B-AM-PNC', 'I-C-AM-LOC', 'I-R-A3', 'I-R-AM-ADV', 'I-A0', 'B-AM-EXT', 'B-R-AM-PNC', 'I-AM-DIS', 'I-AM-REC', 'B-C-AM-LOC', 'B-R-AM-ADV', 'I-AM', 'I-AM-CAU', 'I-AM-TMP', 'I-A1', 'I-C-A4', 'B-R-AM-LOC', 'I-C-A2', 'B-C-A5', 'O', 'B-R-AM-MNR', 'I-C-A3', 'I-R-AM-DIR', 'I-AM-PRD', 'B-AM-TM', 'I-A2', 'I-AA', 'I-AM-LOC', 'I-AM-PNC', 'B-AM-MOD', 'B-AM-DIR', 'B-R-A1', 'B-AM-TMP', 'B-AM-MNR', 'I-R-A0', 'B-AM-PRD', 'I-AM-DIR', 'B-AM-DIS', 'I-C-AM-ADV', 'I-R-A1', 'B-C-A3', 'I-R-AM-MNR', 'I-R-A4', 'I-C-AM-PNC', 'I-C-AM-TMP', 'I-C-V', 'I-A3', 'I-C-AM-EXT', 'B-C-A1', 'B-AA', 'I-C-AM-DIR', 'B-C-AM-PNC'] + ) + ), + } + ), + supervised_keys=None, + homepage="https://www.aclweb.org/anthology/W03-0419/", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + urls_to_download = { + "train": f"{_URL}{_TRAINING_FILE}", + "dev": f"{_URL}{_DEV_FILE}", + "test_wsj": f"{_URL}{_TEST_WSJ_FILE}", + "test_brown": f"{_URL}{_TEST_BROWN_FILE}" + } + downloaded_files = dl_manager.download_and_extract(urls_to_download) + + return [ + datasets.SplitGenerator(name="train", gen_kwargs={"filepath": downloaded_files["train"]}), + datasets.SplitGenerator(name="validation", gen_kwargs={"filepath": downloaded_files["dev"]}), + datasets.SplitGenerator(name="test_wsj", gen_kwargs={"filepath": downloaded_files["test_wsj"]}), + datasets.SplitGenerator(name="test_brown", gen_kwargs={"filepath": downloaded_files["test_brown"]}), + ] + + def _generate_examples(self, filepath): + logger.info("⏳ Generating examples from = %s", filepath) + with open(filepath, encoding="utf-8") as f: + guid = 0 + for line in f: + if line != '': + index = line.split()[0] + + text = ' '.join(line.split()[1:]).strip() + tokens = text.split("|||")[0].split() + labels = text.split("|||")[1].split() + yield guid, { + "id": str(guid), + "index": index, + "tokens": tokens, + "tags": labels + } + + guid += 1 \ No newline at end of file diff --git a/soft_prompt/tasks/srl/datasets/conll2012.py b/soft_prompt/tasks/srl/datasets/conll2012.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac51f992a203003aa91cf69dc9c3230436fd4a2 --- /dev/null +++ b/soft_prompt/tasks/srl/datasets/conll2012.py @@ -0,0 +1,152 @@ +# coding=utf-8 +# Copyright 2020 HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition""" + +import datasets + + +logger = datasets.logging.get_logger(__name__) + + +_CITATION = """\ +@inproceedings{tjong-kim-sang-de-meulder-2003-introduction, + title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition", + author = "Tjong Kim Sang, Erik F. and + De Meulder, Fien", + booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003", + year = "2003", + url = "https://www.aclweb.org/anthology/W03-0419", + pages = "142--147", +} +""" + +_DESCRIPTION = """\ +The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on +four types of named entities: persons, locations, organizations and names of miscellaneous entities that do +not belong to the previous three groups. +The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on +a separate line and there is an empty line after each sentence. The first item on each line is a word, the second +a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags +and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only +if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag +B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2 +tagging scheme, whereas the original dataset uses IOB1. +For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419 +""" + +_URL = "../../../data/CoNLL12/" +_TRAINING_FILE = "conll2012.train.txt" +_DEV_FILE = "conll2012.devel.txt" +_TEST_WSJ_FILE = "conll2012.test.txt" +# _TEST_BROWN_FILE = "conll.test.brown.txt" +CONLL12_LABELS = ['B-ARG0', 'B-ARGM-MNR', 'B-V', 'B-ARG1', 'B-ARG2', 'I-ARG2', 'O', 'I-ARG1', 'B-ARGM-ADV', + 'B-ARGM-LOC', 'I-ARGM-LOC', 'I-ARG0', 'B-ARGM-TMP', 'I-ARGM-TMP', 'B-ARGM-PRP', + 'I-ARGM-PRP', 'B-ARGM-PRD', 'I-ARGM-PRD', 'B-R-ARGM-TMP', 'B-ARGM-DIR', 'I-ARGM-DIR', + 'B-ARGM-DIS', 'B-ARGM-MOD', 'I-ARGM-ADV', 'I-ARGM-DIS', 'B-R-ARGM-LOC', 'B-ARG4', + 'I-ARG4', 'B-R-ARG1', 'B-R-ARG0', 'I-R-ARG0', 'B-ARG3', 'B-ARGM-NEG', 'B-ARGM-CAU', + 'I-ARGM-MNR', 'I-R-ARG1', 'B-C-ARG1', 'I-C-ARG1', 'B-ARGM-EXT', 'I-ARGM-EXT', 'I-ARGM-CAU', + 'I-ARG3', 'B-C-ARGM-ADV', 'I-C-ARGM-ADV', 'B-ARGM-LVB', 'B-ARGM-REC', 'B-R-ARG3', + 'B-R-ARG2', 'B-C-ARG0', 'I-C-ARG0', 'B-ARGM-ADJ', 'B-C-ARG2', 'I-C-ARG2', 'B-R-ARGM-CAU', + 'B-R-ARGM-DIR', 'B-ARGM-GOL', 'I-ARGM-GOL', 'B-ARGM-DSP', 'I-ARGM-ADJ', 'I-R-ARG2', + 'I-ARGM-NEG', 'B-ARGM-PRR', 'B-R-ARGM-ADV', 'I-R-ARGM-ADV', 'I-R-ARGM-LOC', 'B-ARGA', + 'B-R-ARGM-MNR', 'I-R-ARGM-MNR', 'B-ARGM-COM', 'I-ARGM-COM', 'B-ARGM-PRX', 'I-ARGM-REC', + 'B-R-ARG4', 'B-C-ARGM-LOC', 'I-C-ARGM-LOC', 'I-R-ARGM-DIR', 'I-ARGA', 'B-C-ARGM-TMP', + 'I-C-ARGM-TMP', 'B-C-ARGM-CAU', 'I-C-ARGM-CAU', 'B-R-ARGM-PRD', 'I-R-ARGM-PRD', + 'I-R-ARG3', 'B-C-ARG4', 'I-C-ARG4', 'B-ARGM-PNC', 'I-ARGM-PNC', 'B-ARG5', 'I-ARG5', + 'B-C-ARGM-PRP', 'I-C-ARGM-PRP', 'B-C-ARGM-MNR', 'I-C-ARGM-MNR', 'I-R-ARGM-TMP', + 'B-R-ARG5', 'I-ARGM-DSP', 'B-C-ARGM-DSP', 'I-C-ARGM-DSP', 'B-C-ARG3', 'I-C-ARG3', + 'B-R-ARGM-COM', 'I-R-ARGM-COM', 'B-R-ARGM-PRP', 'I-R-ARGM-PRP', 'I-R-ARGM-CAU', + 'B-R-ARGM-GOL', 'I-R-ARGM-GOL', 'B-R-ARGM-EXT', 'I-R-ARGM-EXT', 'I-R-ARG4', + 'B-C-ARGM-EXT', 'I-C-ARGM-EXT', 'I-ARGM-MOD', 'B-C-ARGM-MOD', 'I-C-ARGM-MOD'] + + +class Conll2005Config(datasets.BuilderConfig): + """BuilderConfig for Conll2003""" + + def __init__(self, **kwargs): + """BuilderConfig forConll2005. + Args: + **kwargs: keyword arguments forwarded to super. + """ + super(Conll2005Config, self).__init__(**kwargs) + + +class Conll2005(datasets.GeneratorBasedBuilder): + """Conll2003 dataset.""" + + BUILDER_CONFIGS = [ + Conll2005Config(name="conll2012", version=datasets.Version("1.0.0"), description="Conll2012 dataset"), + ] + + def _info(self): + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=datasets.Features( + { + "id": datasets.Value("string"), + "index": datasets.Value("string"), + "tokens": datasets.Sequence(datasets.Value("string")), + "tags": datasets.Sequence( + datasets.features.ClassLabel( + names=CONLL12_LABELS + ) + ), + } + ), + supervised_keys=None, + homepage="https://www.aclweb.org/anthology/W03-0419/", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + urls_to_download = { + "train": f"{_URL}{_TRAINING_FILE}", + "dev": f"{_URL}{_DEV_FILE}", + "test_wsj": f"{_URL}{_TEST_WSJ_FILE}", + # "test_brown": f"{_URL}{_TEST_BROWN_FILE}" + } + downloaded_files = dl_manager.download_and_extract(urls_to_download) + + return [ + datasets.SplitGenerator(name="train", gen_kwargs={"filepath": downloaded_files["train"]}), + datasets.SplitGenerator(name="validation", gen_kwargs={"filepath": downloaded_files["dev"]}), + datasets.SplitGenerator(name="test_wsj", gen_kwargs={"filepath": downloaded_files["test_wsj"]}), + # datasets.SplitGenerator(name="test_brown", gen_kwargs={"filepath": downloaded_files["test_brown"]}), + ] + + def _generate_examples(self, filepath): + logger.info("⏳ Generating examples from = %s", filepath) + with open(filepath, encoding="utf-8") as f: + guid = 0 + for line in f: + if line != '': + if line.split() == []: + continue + index = line.split()[0] + + text = ' '.join(line.split()[1:]).strip() + tokens = text.split("|||")[0].split() + labels = text.split("|||")[1].split() + yield guid, { + "id": str(guid), + "index": index, + "tokens": tokens, + "tags": labels + } + + guid += 1 \ No newline at end of file diff --git a/soft_prompt/tasks/srl/get_trainer.py b/soft_prompt/tasks/srl/get_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8ccd2b95cc66fffe1de735a15b50813846af9252 --- /dev/null +++ b/soft_prompt/tasks/srl/get_trainer.py @@ -0,0 +1,61 @@ +import logging +import os +import random +import sys + +from transformers import ( + AutoConfig, + AutoTokenizer, +) + +from tasks.srl.dataset import SRLDataset +from training.trainer_exp import ExponentialTrainer +from model.utils import get_model, TaskType +from tasks.utils import ADD_PREFIX_SPACE, USE_FAST + +logger = logging.getLogger(__name__) + +def get_trainer(args): + model_args, data_args, training_args, _ = args + + model_type = AutoConfig.from_pretrained(model_args.model_name_or_path).model_type + + add_prefix_space = ADD_PREFIX_SPACE[model_type] + + use_fast = USE_FAST[model_type] + + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=use_fast, + revision=model_args.model_revision, + add_prefix_space=add_prefix_space, + ) + + dataset = SRLDataset(tokenizer, data_args, training_args) + + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + revision=model_args.model_revision, + ) + + if training_args.do_train: + for index in random.sample(range(len(dataset.train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {dataset.train_dataset[index]}.") + + model = get_model(model_args, TaskType.TOKEN_CLASSIFICATION, config, fix_bert=False) + + + trainer = ExponentialTrainer( + model=model, + args=training_args, + train_dataset=dataset.train_dataset if training_args.do_train else None, + eval_dataset=dataset.eval_dataset if training_args.do_eval else None, + predict_dataset=dataset.predict_dataset if training_args.do_predict else None, + tokenizer=tokenizer, + data_collator=dataset.data_collator, + compute_metrics=dataset.compute_metrics, + test_key="f1" + ) + + return trainer, dataset.predict_dataset \ No newline at end of file diff --git a/soft_prompt/tasks/superglue/dataset.py b/soft_prompt/tasks/superglue/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1e655736f8cd4aaa6fc98aba3f8896476bf7e138 --- /dev/null +++ b/soft_prompt/tasks/superglue/dataset.py @@ -0,0 +1,257 @@ +import os.path + +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + default_data_collator, +) +import hashlib, torch +import numpy as np +import logging +from collections import defaultdict + +task_to_keys = { + "boolq": ("question", "passage"), + "cb": ("premise", "hypothesis"), + "rte": ("premise", "hypothesis"), + "wic": ("processed_sentence1", None), + "wsc": ("span2_word_text", "span1_text"), + "copa": (None, None), + "record": (None, None), + "multirc": ("paragraph", "question_answer") +} + +logger = logging.getLogger(__name__) + + +class SuperGlueDataset(): + def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None: + super().__init__() + raw_datasets = load_dataset("super_glue", data_args.dataset_name) + self.tokenizer = tokenizer + self.data_args = data_args + + self.multiple_choice = data_args.dataset_name in ["copa"] + + if data_args.dataset_name == "record": + self.num_labels = 2 + self.label_list = ["0", "1"] + elif not self.multiple_choice: + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + else: + self.num_labels = 1 + + # Preprocessing the raw_datasets + self.sentence1_key, self.sentence2_key = task_to_keys[data_args.dataset_name] + + # Padding strategy + if data_args.pad_to_max_length: + self.padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + self.padding = False + + if not self.multiple_choice: + self.label2id = {l: i for i, l in enumerate(self.label_list)} + self.id2label = {id: label for label, id in self.label2id.items()} + print(f"{self.label2id}") + print(f"{self.id2label}") + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + if data_args.dataset_name == "record": + digest = hashlib.md5(f"record_{tokenizer.name_or_path}".encode("utf-8")).hexdigest()[:16] # 16 byte binary + path = raw_datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"record-{digest}.arrow") + if not os.path.exists(path): + print(f"-> path not found!:{path}") + raw_datasets = raw_datasets.map( + self.record_preprocess_function, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + remove_columns=raw_datasets["train"].column_names, + desc="Running tokenizer on dataset", + ) + data = {"raw_datasets": raw_datasets} + torch.save(data, path) + raw_datasets = torch.load(path)["raw_datasets"] + else: + raw_datasets = raw_datasets.map( + self.preprocess_function, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on dataset", + ) + + if training_args.do_train: + self.train_dataset = raw_datasets["train"] + if data_args.max_train_samples is not None: + self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + self.eval_dataset = raw_datasets["validation"] + if data_args.max_eval_samples is not None: + self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples)) + + if training_args.do_predict or data_args.dataset_name is not None or data_args.test_file is not None: + self.predict_dataset = raw_datasets["test"] + if data_args.max_predict_samples is not None: + self.predict_dataset = self.predict_dataset.select(range(data_args.max_predict_samples)) + + self.metric = load_metric("super_glue", data_args.dataset_name) + + if data_args.pad_to_max_length: + self.data_collator = default_data_collator + elif training_args.fp16: + self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + + self.test_key = "accuracy" if data_args.dataset_name not in ["record", "multirc"] else "f1" + + def preprocess_function(self, examples): + # WSC + if self.data_args.dataset_name == "wsc": + examples["span2_word_text"] = [] + for text, span2_index, span2_word in zip(examples["text"], examples["span2_index"], examples["span2_text"]): + if self.data_args.template_id == 0: + examples["span2_word_text"].append(span2_word + ": " + text) + elif self.data_args.template_id == 1: + words_a = text.split() + words_a[span2_index] = "*" + words_a[span2_index] + "*" + examples["span2_word_text"].append(' '.join(words_a)) + + # WiC + if self.data_args.dataset_name == "wic": + examples["processed_sentence1"] = [] + if self.data_args.template_id == 1: + self.sentence2_key = "processed_sentence2" + examples["processed_sentence2"] = [] + for sentence1, sentence2, word, start1, end1, start2, end2 in zip(examples["sentence1"], + examples["sentence2"], examples["word"], + examples["start1"], examples["end1"], + examples["start2"], examples["end2"]): + if self.data_args.template_id == 0: # ROBERTA + examples["processed_sentence1"].append( + f"{sentence1} {sentence2} Does {word} have the same meaning in both sentences?") + elif self.data_args.template_id == 1: # BERT + examples["processed_sentence1"].append(word + ": " + sentence1) + examples["processed_sentence2"].append(word + ": " + sentence2) + + # MultiRC + if self.data_args.dataset_name == "multirc": + examples["question_answer"] = [] + for question, asnwer in zip(examples["question"], examples["answer"]): + examples["question_answer"].append(f"{question} {asnwer}") + + # COPA + if self.data_args.dataset_name == "copa": + examples["text_a"] = [] + for premise, question in zip(examples["premise"], examples["question"]): + joiner = "because" if question == "cause" else "so" + text_a = f"{premise} {joiner}" + examples["text_a"].append(text_a) + + result1 = self.tokenizer(examples["text_a"], examples["choice1"], padding=self.padding, + max_length=self.max_seq_length, truncation=True) + result2 = self.tokenizer(examples["text_a"], examples["choice2"], padding=self.padding, + max_length=self.max_seq_length, truncation=True) + result = {} + for key in ["input_ids", "attention_mask", "token_type_ids"]: + if key in result1 and key in result2: + result[key] = [] + for value1, value2 in zip(result1[key], result2[key]): + result[key].append([value1, value2]) + return result + + args = ( + (examples[self.sentence1_key],) if self.sentence2_key is None else ( + examples[self.sentence1_key], examples[self.sentence2_key]) + ) + result = self.tokenizer(*args, padding=self.padding, max_length=self.max_seq_length, truncation=True) + + return result + + def compute_metrics(self, p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.argmax(preds, axis=1) + + if self.data_args.dataset_name == "record": + return self.reocrd_compute_metrics(p) + + if self.data_args.dataset_name == "multirc": + from sklearn.metrics import f1_score + return {"f1": f1_score(preds, p.label_ids)} + + if self.data_args.dataset_name is not None: + result = self.metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif self.is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + def reocrd_compute_metrics(self, p: EvalPrediction): + from .utils import f1_score, exact_match_score, metric_max_over_ground_truths + probs = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + examples = self.eval_dataset + qid2pred = defaultdict(list) + qid2ans = {} + for prob, example in zip(probs, examples): + qid = example['question_id'] + qid2pred[qid].append((prob[1], example['entity'])) + if qid not in qid2ans: + qid2ans[qid] = example['answers'] + n_correct, n_total = 0, 0 + f1, em = 0, 0 + for qid in qid2pred: + preds = sorted(qid2pred[qid], reverse=True) + entity = preds[0][1] + n_total += 1 + n_correct += (entity in qid2ans[qid]) + f1 += metric_max_over_ground_truths(f1_score, entity, qid2ans[qid]) + em += metric_max_over_ground_truths(exact_match_score, entity, qid2ans[qid]) + acc = n_correct / n_total + f1 = f1 / n_total + em = em / n_total + return {'f1': f1, 'exact_match': em} + + def record_preprocess_function(self, examples, split="train"): + results = { + "index": list(), + "question_id": list(), + "input_ids": list(), + "attention_mask": list(), + #"token_type_ids": list(), + "label": list(), + "entity": list(), + "answers": list() + } + for idx, passage in enumerate(examples["passage"]): + query, entities, answers = examples["query"][idx], examples["entities"][idx], examples["answers"][idx] + index = examples["idx"][idx] + passage = passage.replace("@highlight\n", "- ").replace(self.tokenizer.prompt_token, "").replace(self.tokenizer.skey_token, "").replace(self.tokenizer.predict_token, "") + + for ent_idx, ent in enumerate(entities): + question = query.replace("@placeholder", ent).replace(self.tokenizer.prompt_token, "").replace(self.tokenizer.skey_token, "").replace(self.tokenizer.predict_token, "") + result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, + truncation=True) + label = 1 if ent in answers else 0 + + results["input_ids"].append(result["input_ids"]) + results["attention_mask"].append(result["attention_mask"]) + #if "token_type_ids" in result.keys(): results["token_type_ids"].append(result["token_type_ids"]) + results["label"].append(label) + results["index"].append(index) + results["question_id"].append(index["query"]) + results["entity"].append(ent) + results["answers"].append(answers) + + return results \ No newline at end of file diff --git a/soft_prompt/tasks/superglue/dataset_record.py b/soft_prompt/tasks/superglue/dataset_record.py new file mode 100644 index 0000000000000000000000000000000000000000..a63898418bef33db38a17afb565f42a53675e2d0 --- /dev/null +++ b/soft_prompt/tasks/superglue/dataset_record.py @@ -0,0 +1,251 @@ +import torch +from torch.utils import data +from torch.utils.data import Dataset +from datasets.arrow_dataset import Dataset as HFDataset +from datasets.load import load_dataset, load_metric +from transformers import ( + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + default_data_collator, + DataCollatorForLanguageModeling +) +import random +import numpy as np +import logging + +from .dataset import SuperGlueDataset + +from dataclasses import dataclass +from transformers.data.data_collator import DataCollatorMixin +from transformers.file_utils import PaddingStrategy +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union + +logger = logging.getLogger(__name__) + +@dataclass +class DataCollatorForMultipleChoice(DataCollatorMixin): + tokenizer: PreTrainedTokenizerBase + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + return_tensors: str = "pt" + + def torch_call(self, features): + label_name = "label" if "label" in features[0].keys() else "labels" + labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None + batch = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + # Conversion to tensors will fail if we have labels as they are not of the same length yet. + return_tensors="pt" if labels is None else None, + ) + + if labels is None: + return batch + + sequence_length = torch.tensor(batch["input_ids"]).shape[1] + padding_side = self.tokenizer.padding_side + if padding_side == "right": + batch[label_name] = [ + list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels + ] + else: + batch[label_name] = [ + [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels + ] + + batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} + print(batch) + input_list = [sample['input_ids'] for sample in batch] + + choice_nums = list(map(len, input_list)) + max_choice_num = max(choice_nums) + + def pad_choice_dim(data, choice_num): + if len(data) < choice_num: + data = np.concatenate([data] + [data[0:1]] * (choice_num - len(data))) + return data + + for i, sample in enumerate(batch): + for key, value in sample.items(): + if key != 'label': + sample[key] = pad_choice_dim(value, max_choice_num) + else: + sample[key] = value + # sample['loss_mask'] = np.array([1] * choice_nums[i] + [0] * (max_choice_num - choice_nums[i]), + # dtype=np.int64) + + return batch + + +class SuperGlueDatasetForRecord(SuperGlueDataset): + def __init__(self, tokenizer: AutoTokenizer, data_args, training_args) -> None: + raw_datasets = load_dataset("super_glue", data_args.dataset_name) + self.tokenizer = tokenizer + self.data_args = data_args + #labels + self.multiple_choice = data_args.dataset_name in ["copa", "record"] + + if not self.multiple_choice: + self.label_list = raw_datasets["train"].features["label"].names + self.num_labels = len(self.label_list) + else: + self.num_labels = 1 + + # Padding strategy + if data_args.pad_to_max_length: + self.padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + self.padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + self.label_to_id = None + + if self.label_to_id is not None: + self.label2id = self.label_to_id + self.id2label = {id: label for label, id in self.label2id.items()} + elif not self.multiple_choice: + self.label2id = {l: i for i, l in enumerate(self.label_list)} + self.id2label = {id: label for label, id in self.label2id.items()} + + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warning( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + self.max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + if training_args.do_train: + self.train_dataset = raw_datasets["train"] + if data_args.max_train_samples is not None: + self.train_dataset = self.train_dataset.select(range(data_args.max_train_samples)) + + self.train_dataset = self.train_dataset.map( + self.prepare_train_dataset, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + remove_columns=raw_datasets["train"].column_names, + desc="Running tokenizer on train dataset", + ) + + if training_args.do_eval: + self.eval_dataset = raw_datasets["validation"] + if data_args.max_eval_samples is not None: + self.eval_dataset = self.eval_dataset.select(range(data_args.max_eval_samples)) + + self.eval_dataset = self.eval_dataset.map( + self.prepare_eval_dataset, + batched=True, + load_from_cache_file=not data_args.overwrite_cache, + remove_columns=raw_datasets["train"].column_names, + desc="Running tokenizer on validation dataset", + ) + + self.metric = load_metric("super_glue", data_args.dataset_name) + + self.data_collator = DataCollatorForMultipleChoice(tokenizer) + # if data_args.pad_to_max_length: + # self.data_collator = default_data_collator + # elif training_args.fp16: + # self.data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + def preprocess_function(self, examples): + results = { + "input_ids": list(), + "attention_mask": list(), + "token_type_ids": list(), + "label": list() + } + for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]): + passage = passage.replace("@highlight\n", "- ") + + input_ids = [] + attention_mask = [] + token_type_ids = [] + + for _, ent in enumerate(entities): + question = query.replace("@placeholder", ent) + result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True) + + input_ids.append(result["input_ids"]) + attention_mask.append(result["attention_mask"]) + if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"]) + label = 1 if ent in answers else 0 + + result["label"].append() + + return results + + + def prepare_train_dataset(self, examples, max_train_candidates_per_question=10): + entity_shuffler = random.Random(44) + results = { + "input_ids": list(), + "attention_mask": list(), + "token_type_ids": list(), + "label": list() + } + for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]): + passage = passage.replace("@highlight\n", "- ") + + for answer in answers: + input_ids = [] + attention_mask = [] + token_type_ids = [] + candidates = [ent for ent in entities if ent not in answers] + # if len(candidates) < max_train_candidates_per_question - 1: + # continue + if len(candidates) > max_train_candidates_per_question - 1: + entity_shuffler.shuffle(candidates) + candidates = candidates[:max_train_candidates_per_question - 1] + candidates = [answer] + candidates + + for ent in candidates: + question = query.replace("@placeholder", ent) + result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True) + input_ids.append(result["input_ids"]) + attention_mask.append(result["attention_mask"]) + if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"]) + + results["input_ids"].append(input_ids) + results["attention_mask"].append(attention_mask) + if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids) + results["label"].append(0) + + return results + + + def prepare_eval_dataset(self, examples): + + results = { + "input_ids": list(), + "attention_mask": list(), + "token_type_ids": list(), + "label": list() + } + for passage, query, entities, answers in zip(examples["passage"], examples["query"], examples["entities"], examples["answers"]): + passage = passage.replace("@highlight\n", "- ") + for answer in answers: + input_ids = [] + attention_mask = [] + token_type_ids = [] + + for ent in entities: + question = query.replace("@placeholder", ent) + result = self.tokenizer(passage, question, padding=self.padding, max_length=self.max_seq_length, truncation=True) + input_ids.append(result["input_ids"]) + attention_mask.append(result["attention_mask"]) + if "token_type_ids" in result: token_type_ids.append(result["token_type_ids"]) + + results["input_ids"].append(input_ids) + results["attention_mask"].append(attention_mask) + if len(token_type_ids) > 0: results["token_type_ids"].append(token_type_ids) + results["label"].append(0) + + return results diff --git a/soft_prompt/tasks/superglue/get_trainer.py b/soft_prompt/tasks/superglue/get_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7368b13324d6efd40bf4c17e5a3a22f5b1977766 --- /dev/null +++ b/soft_prompt/tasks/superglue/get_trainer.py @@ -0,0 +1,106 @@ +import logging +import os +import random +import torch + +from transformers import ( + AutoConfig, + AutoTokenizer, +) + +from model.utils import get_model, TaskType +from tasks.superglue.dataset import SuperGlueDataset +from training import BaseTrainer +from training.trainer_exp import ExponentialTrainer +from tasks import utils +from .utils import load_from_cache + +logger = logging.getLogger(__name__) + +def get_trainer(args): + model_args, data_args, training_args, _ = args + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + model_args.model_name_or_path = load_from_cache(model_args.model_name_or_path) + + if "llama" in model_args.model_name_or_path: + from transformers import LlamaTokenizer + model_path = f'openlm-research/{model_args.model_name_or_path}' + tokenizer = LlamaTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.mask_token = tokenizer.unk_token + tokenizer.mask_token_id = tokenizer.unk_token_id + elif 'gpt' in model_args.model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer.pad_token_id = '<|endoftext|>' + tokenizer.pad_token = '<|endoftext|>' + else: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + ) + tokenizer = utils.add_task_specific_tokens(tokenizer) + dataset = SuperGlueDataset(tokenizer, data_args, training_args) + + if training_args.do_train: + for index in random.sample(range(len(dataset.train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {dataset.train_dataset[index]}.") + + if not dataset.multiple_choice: + if "llama" in model_args.model_name_or_path: + model_path = f'openlm-research/{model_args.model_name_or_path}' + config = AutoConfig.from_pretrained( + model_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + trust_remote_code=True + ) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + label2id=dataset.label2id, + id2label=dataset.id2label, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + trust_remote_code=True + ) + else: + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, + num_labels=dataset.num_labels, + finetuning_task=data_args.dataset_name, + revision=model_args.model_revision, + ) + + config.trigger = training_args.trigger + config.clean_labels = training_args.clean_labels + config.target_labels = training_args.target_labels + + if not dataset.multiple_choice: + model = get_model(model_args, TaskType.SEQUENCE_CLASSIFICATION, config) + else: + model = get_model(model_args, TaskType.MULTIPLE_CHOICE, config, fix_bert=True) + + # Initialize our Trainer + trainer = BaseTrainer( + model=model, + args=training_args, + train_dataset=dataset.train_dataset if training_args.do_train else None, + eval_dataset=dataset.eval_dataset if training_args.do_eval else None, + compute_metrics=dataset.compute_metrics, + tokenizer=tokenizer, + data_collator=dataset.data_collator, + test_key=dataset.test_key + ) + + + return trainer, None diff --git a/soft_prompt/tasks/superglue/utils.py b/soft_prompt/tasks/superglue/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..544e3f92b61f8fd6cae994d9f98677ec8e2d7fd9 --- /dev/null +++ b/soft_prompt/tasks/superglue/utils.py @@ -0,0 +1,51 @@ +import re, os +import string +from collections import defaultdict, Counter + +def load_from_cache(model_name): + path = os.path.join("hub/models", model_name) + if os.path.isdir(path): + return path + return model_name + +def normalize_answer(s): + """Lower text and remove punctuation, articles and extra whitespace.""" + + def remove_articles(text): + return re.sub(r'\b(a|an|the)\b', ' ', text) + + def white_space_fix(text): + return ' '.join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return ''.join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + +def f1_score(prediction, ground_truth): + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) + return f1 + + +def exact_match_score(prediction, ground_truth): + return normalize_answer(prediction) == normalize_answer(ground_truth) + + +def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): + scores_for_ground_truths = [] + for ground_truth in ground_truths: + score = metric_fn(prediction, ground_truth) + scores_for_ground_truths.append(score) + return max(scores_for_ground_truths) \ No newline at end of file diff --git a/soft_prompt/tasks/utils.py b/soft_prompt/tasks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53d0bb97b383043df45e92a1bb5d577f7117abcc --- /dev/null +++ b/soft_prompt/tasks/utils.py @@ -0,0 +1,79 @@ +import os +import torch +from tqdm import tqdm +from tasks.glue.dataset import task_to_keys as glue_tasks +from tasks.superglue.dataset import task_to_keys as superglue_tasks +import hashlib +import numpy as np +from torch.nn.utils.rnn import pad_sequence + +GLUE_DATASETS = list(glue_tasks.keys()) +SUPERGLUE_DATASETS = list(superglue_tasks.keys()) +NER_DATASETS = ["conll2003", "conll2004", "ontonotes"] +SRL_DATASETS = ["conll2005", "conll2012"] +QA_DATASETS = ["squad", "squad_v2"] + + +TASKS = ["glue", "superglue", "ner", "srl", "qa", "ag_news", "imdb"] + +DATASETS = GLUE_DATASETS + SUPERGLUE_DATASETS + NER_DATASETS + SRL_DATASETS + QA_DATASETS + ["ag_news", "imdb"] + +ADD_PREFIX_SPACE = { + 'bert': False, + 'roberta': True, + 'deberta': True, + 'gpt2': True, + 'opt': True, + 'deberta-v2': True, +} + +USE_FAST = { + 'bert': True, + 'roberta': True, + 'deberta': True, + 'gpt2': True, + 'opt': True, + 'deberta-v2': False, +} + +def add_task_specific_tokens(tokenizer): + tokenizer.add_special_tokens({ + 'additional_special_tokens': ['[P]', '[T]', '[K]', '[Y]'] + }) + tokenizer.skey_token = '[K]' + tokenizer.skey_token_id = tokenizer.convert_tokens_to_ids('[K]') + tokenizer.prompt_token = '[T]' + tokenizer.prompt_token_id = tokenizer.convert_tokens_to_ids('[T]') + tokenizer.predict_token = '[P]' + tokenizer.predict_token_id = tokenizer.convert_tokens_to_ids('[P]') + # NOTE: BERT and RoBERTa tokenizers work properly if [X] is not a special token... + # tokenizer.lama_x = '[X]' + # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[X]') + # tokenizer.lama_y = '[Y]' + # tokenizer.lama_x_id = tokenizer.convert_tokens_to_ids('[Y]') + + # only for GPT2 + if 'gpt' in tokenizer.name_or_path or 'opt' in tokenizer.name_or_path: + tokenizer.mask_token = tokenizer.unk_token + tokenizer.pad_token = tokenizer.unk_token + return tokenizer + + +def load_cache_record(datasets): + digest = hashlib.md5("record".encode("utf-8")).hexdigest() # 16 byte binary + path = datasets["train"]._get_cache_file_path("").replace("cache-.arrow", f"cache-clean+poison-{digest}.arrow") + if not os.path.exists(path): + return torch.load(path) + return None + + + + + + + + + + + + \ No newline at end of file diff --git a/soft_prompt/training/__init__.py b/soft_prompt/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..abba471bd45ad30ebea2bc23363270d936f286d9 --- /dev/null +++ b/soft_prompt/training/__init__.py @@ -0,0 +1,2 @@ +from .trainer_base import BaseTrainer +from .trainer import Trainer \ No newline at end of file diff --git a/soft_prompt/training/trainer.py b/soft_prompt/training/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3d87d61f98a37df45bf757f3eb9c46ca9cf43d --- /dev/null +++ b/soft_prompt/training/trainer.py @@ -0,0 +1,3990 @@ +# coding=utf-8 +# Copyright 2020-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. +""" + +import contextlib +import functools +import glob +import inspect +import math +import os +import random +import re +import shutil +import sys +import time +import warnings +from collections.abc import Mapping +from pathlib import Path +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union + +from tqdm.auto import tqdm + + +# Integrations must be imported before ML frameworks: +# isort: off +from transformers.integrations import ( + default_hp_search_backend, + get_reporting_integration_callbacks, + hp_params, + is_fairscale_available, + is_optuna_available, + is_ray_tune_available, + is_sigopt_available, + is_wandb_available, + run_hp_search_optuna, + run_hp_search_ray, + run_hp_search_sigopt, + run_hp_search_wandb, +) + +# isort: on + +import numpy as np +import torch +import torch.distributed as dist +from huggingface_hub import Repository, create_repo +from packaging import version +from torch import nn +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler + +from transformers import __version__ +from transformers.configuration_utils import PretrainedConfig +from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from transformers.debug_utils import DebugOption, DebugUnderflowOverflow +from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled +from transformers.dependency_versions_check import dep_version_check +from transformers.modelcard import TrainingSummary +from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model +from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES +from transformers.optimization import Adafactor, get_scheduler +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11 +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.trainer_callback import ( + CallbackHandler, + DefaultFlowCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, +) +from transformers.trainer_pt_utils import ( + DistributedLengthGroupedSampler, + DistributedSamplerWithLoop, + DistributedTensorGatherer, + IterableDatasetShard, + LabelSmoother, + LengthGroupedSampler, + SequentialDistributedSampler, + ShardSampler, + distributed_broadcast_scalars, + distributed_concat, + find_batch_size, + get_model_param_count, + get_module_class_from_name, + get_parameter_names, + nested_concat, + nested_detach, + nested_numpify, + nested_truncate, + nested_xla_mesh_reduce, + reissue_pt_warnings, +) +from transformers.trainer_utils import ( + PREFIX_CHECKPOINT_DIR, + BestRun, + EvalLoopOutput, + EvalPrediction, + FSDPOption, + HPSearchBackend, + HubStrategy, + IntervalStrategy, + PredictionOutput, + RemoveColumnsCollator, + ShardedDDPOption, + TrainerMemoryTracker, + TrainOutput, + default_compute_objective, + default_hp_space, + denumpify_detensorize, + enable_full_determinism, + find_executable_batch_size, + get_last_checkpoint, + has_length, + number_of_arguments, + seed_worker, + set_seed, + speed_metrics, +) +from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments +from transformers.utils import ( + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + CONFIG_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + can_return_loss, + find_labels, + get_full_repo_name, + is_accelerate_available, + is_apex_available, + is_datasets_available, + is_in_notebook, + is_ipex_available, + is_peft_available, + is_safetensors_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_torch_compile_available, + is_torch_neuroncore_available, + is_torch_tpu_available, + logging, + strtobool, +) +from transformers.utils.generic import ContextManagers + + +_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10 + +DEFAULT_CALLBACKS = [DefaultFlowCallback] +DEFAULT_PROGRESS_CALLBACK = ProgressCallback + +if is_in_notebook(): + from transformers.utils.notebook import NotebookProgressCallback + + DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback + +if is_apex_available(): + from apex import amp + +if is_datasets_available(): + import datasets + +if is_torch_tpu_available(check_device=False): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + import torch_xla.distributed.parallel_loader as pl + +if is_fairscale_available(): + dep_version_check("fairscale") + import fairscale + from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP + from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP + from fairscale.nn.wrap import auto_wrap + from fairscale.optim import OSS + from fairscale.optim.grad_scaler import ShardedGradScaler + + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + from smdistributed.modelparallel import __version__ as SMP_VERSION + + IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") + + from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat +else: + IS_SAGEMAKER_MP_POST_1_10 = False + + +if is_safetensors_available(): + import safetensors.torch + + +if is_peft_available(): + from peft import PeftModel + + +skip_first_batches = None +if is_accelerate_available(): + from accelerate import __version__ as accelerate_version + + if version.parse(accelerate_version) >= version.parse("0.16"): + from accelerate import skip_first_batches + + from accelerate import Accelerator + from accelerate.utils import DistributedDataParallelKwargs + + +if TYPE_CHECKING: + import optuna + +logger = logging.get_logger(__name__) + + +# Name of the files used for checkpointing +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" +OPTIMIZER_NAME = "optimizer.pt" +SCHEDULER_NAME = "scheduler.pt" +SCALER_NAME = "scaler.pt" + + +class Trainer: + """ + Trainer is a simple but feature-complete training and eval loop for PyTorch, optimized for 🤗 Transformers. + + Args: + model ([`PreTrainedModel`] or `torch.nn.Module`, *optional*): + The model to train, evaluate or use for predictions. If not provided, a `model_init` must be passed. + + + + [`Trainer`] is optimized to work with the [`PreTrainedModel`] provided by the library. You can still use + your own models defined as `torch.nn.Module` as long as they work the same way as the 🤗 Transformers + models. + + + + args ([`TrainingArguments`], *optional*): + The arguments to tweak for training. Will default to a basic instance of [`TrainingArguments`] with the + `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided. + data_collator (`DataCollator`, *optional*): + The function to use to form a batch from a list of elements of `train_dataset` or `eval_dataset`. Will + default to [`default_data_collator`] if no `tokenizer` is provided, an instance of + [`DataCollatorWithPadding`] otherwise. + train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*): + The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. + + Note that if it's a `torch.utils.data.IterableDataset` with some randomization and you are training in a + distributed fashion, your iterable dataset should either use a internal attribute `generator` that is a + `torch.Generator` for the randomization that must be identical on all processes (and the Trainer will + manually set the seed of this `generator` at each epoch) or have a `set_epoch()` method that internally + sets the seed of the RNGs used. + eval_dataset (Union[`torch.utils.data.Dataset`, Dict[str, `torch.utils.data.Dataset`]), *optional*): + The dataset to use for evaluation. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. If it is a dictionary, it will evaluate on each + dataset prepending the dictionary key to the metric name. + tokenizer ([`PreTrainedTokenizerBase`], *optional*): + The tokenizer used to preprocess the data. If provided, will be used to automatically pad the inputs to the + maximum length when batching inputs, and it will be saved along the model to make it easier to rerun an + interrupted training or reuse the fine-tuned model. + model_init (`Callable[[], PreTrainedModel]`, *optional*): + A function that instantiates the model to be used. If provided, each call to [`~Trainer.train`] will start + from a new instance of the model as given by this function. + + The function may have zero argument, or a single one containing the optuna/Ray Tune/SigOpt trial object, to + be able to choose different architectures according to hyper parameters (such as layer count, sizes of + inner layers, dropout probabilities etc). + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function that will be used to compute metrics at evaluation. Must take a [`EvalPrediction`] and return + a dictionary string to metric values. + callbacks (List of [`TrainerCallback`], *optional*): + A list of callbacks to customize the training loop. Will add those to the list of default callbacks + detailed in [here](callback). + + If you want to remove one of the default callbacks used, use the [`Trainer.remove_callback`] method. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*): A tuple + containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your model + and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): + A function that preprocess the logits right before caching them at each evaluation step. Must take two + tensors, the logits and the labels, and return the logits once processed as desired. The modifications made + by this function will be reflected in the predictions received by `compute_metrics`. + + Note that the labels (second parameter) will be `None` if the dataset does not have them. + + Important attributes: + + - **model** -- Always points to the core model. If using a transformers model, it will be a [`PreTrainedModel`] + subclass. + - **model_wrapped** -- Always points to the most external model in case one or more other modules wrap the + original model. This is the model that should be used for the forward pass. For example, under `DeepSpeed`, + the inner model is wrapped in `DeepSpeed` and then again in `torch.nn.DistributedDataParallel`. If the inner + model hasn't been wrapped, then `self.model_wrapped` is the same as `self.model`. + - **is_model_parallel** -- Whether or not a model has been switched to a model parallel mode (different from + data parallelism, this means some of the model layers are split on different GPUs). + - **place_model_on_device** -- Whether or not to automatically place the model on the device - it will be set + to `False` if model parallel or deepspeed is used, or if the default + `TrainingArguments.place_model_on_device` is overridden to return `False` . + - **is_in_train** -- Whether or not a model is currently running `train` (e.g. when `evaluate` is called while + in `train`) + + """ + + from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + args: TrainingArguments = None, + data_collator: Optional[DataCollator] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + model_init: Optional[Callable[[], PreTrainedModel]] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ): + if args is None: + output_dir = "tmp_trainer" + logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") + args = TrainingArguments(output_dir=output_dir) + self.args = args + # Seed must be set before instantiating the model when using model + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.hp_name = None + self.is_in_train = False + + self.create_accelerator_and_postprocess() + + # memory metrics - must set up as early as possible + self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) + self._memory_tracker.start() + + # set the correct log level depending on the node + log_level = args.get_process_log_level() + logging.set_verbosity(log_level) + + # force device and distributed setup init explicitly + args._setup_devices + + if model is None: + if model_init is not None: + self.model_init = model_init + model = self.call_model_init() + else: + raise RuntimeError("`Trainer` requires either a `model` or `model_init` argument") + else: + if model_init is not None: + warnings.warn( + "`Trainer` requires either a `model` or `model_init` argument, but not both. `model_init` will" + " overwrite your model when calling the `train` method. This will become a fatal error in the next" + " release.", + FutureWarning, + ) + self.model_init = model_init + + if model.__class__.__name__ in MODEL_MAPPING_NAMES: + raise ValueError( + f"The model you have picked ({model.__class__.__name__}) cannot be used as is for training: it only " + "computes hidden states and does not accept any labels. You should choose a model with a head " + "suitable for your task like any of the `AutoModelForXxx` listed at " + "https://huggingface.co/docs/transformers/model_doc/auto." + ) + + if hasattr(model, "is_parallelizable") and model.is_parallelizable and model.model_parallel: + self.is_model_parallel = True + else: + self.is_model_parallel = False + + if getattr(model, "hf_device_map", None) is not None: + devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]] + if len(devices) > 1: + self.is_model_parallel = True + else: + self.is_model_parallel = self.args.device != torch.device(devices[0]) + + # warn users + logger.info( + "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" + " to `True` to avoid any unexpected behavior such as device placement mismatching." + ) + + # At this stage the model is already loaded + if getattr(model, "is_quantized", False): + if getattr(model, "_is_quantized_training_enabled", False): + logger.info( + "The model is loaded in 8-bit precision. To train this model you need to add additional modules" + " inside the model such as adapters using `peft` library and freeze the model weights. Please" + " check " + " the examples in https://github.com/huggingface/peft for more details." + ) + else: + raise ValueError( + "The model you want to train is loaded in 8-bit precision. if you want to fine-tune an 8-bit" + " model, please make sure that you have installed `bitsandbytes>=0.37.0`. " + ) + + # Setup Sharded DDP training + self.sharded_ddp = None + if len(args.sharded_ddp) > 0: + if self.is_deepspeed_enabled: + raise ValueError( + "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if len(args.fsdp) > 0: + raise ValueError( + "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." + ) + if args.parallel_mode != ParallelMode.DISTRIBUTED: + raise ValueError("Using sharded DDP only works in distributed training.") + elif not is_fairscale_available(): + raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") + elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: + raise ImportError( + "Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " + f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." + ) + elif ShardedDDPOption.SIMPLE in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.SIMPLE + elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 + elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: + self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 + + self.fsdp = None + if len(args.fsdp) > 0: + if self.is_deepspeed_enabled: + raise ValueError( + "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." + ) + if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: + raise ValueError("Using fsdp only works in distributed training.") + + # dep_version_check("torch>=1.12.0") + # Would have to update setup.py with torch>=1.12.0 + # which isn't ideally given that it will force people not using FSDP to also use torch>=1.12.0 + # below is the current alternative. + if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.12.0"): + raise ValueError("FSDP requires PyTorch >= 1.12.0") + + from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, ShardingStrategy + + if FSDPOption.FULL_SHARD in args.fsdp: + self.fsdp = ShardingStrategy.FULL_SHARD + elif FSDPOption.SHARD_GRAD_OP in args.fsdp: + self.fsdp = ShardingStrategy.SHARD_GRAD_OP + elif FSDPOption.NO_SHARD in args.fsdp: + self.fsdp = ShardingStrategy.NO_SHARD + + self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE + if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get( + "backward_prefetch", [] + ): + self.backward_prefetch = BackwardPrefetch.BACKWARD_POST + + self.forward_prefetch = False + if self.args.fsdp_config.get("forward_prefect", False): + self.forward_prefetch = True + + self.limit_all_gathers = False + if self.args.fsdp_config.get("limit_all_gathers", False): + self.limit_all_gathers = True + + # one place to sort out whether to place the model on device or not + # postpone switching model to cuda when: + # 1. MP - since we are trying to fit a much bigger than 1 gpu model + # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, + # and we only use deepspeed for training at the moment + # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first + # 4. Sharded DDP - same as MP + # 5. FSDP - same as MP + self.place_model_on_device = args.place_model_on_device + if ( + self.is_model_parallel + or self.is_deepspeed_enabled + or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) + or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) + or (self.fsdp is not None) + or self.is_fsdp_enabled + ): + self.place_model_on_device = False + + default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) + self.data_collator = data_collator if data_collator is not None else default_collator + self.train_dataset = train_dataset + self.eval_dataset = eval_dataset + self.tokenizer = tokenizer + + if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False): + self._move_model_to_device(model, args.device) + + # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs + if self.is_model_parallel: + self.args._n_gpu = 1 + + # later use `self.model is self.model_wrapped` to check if it's wrapped or not + self.model_wrapped = model + self.model = model + + self.compute_metrics = compute_metrics + self.preprocess_logits_for_metrics = preprocess_logits_for_metrics + self.optimizer, self.lr_scheduler = optimizers + if model_init is not None and (self.optimizer is not None or self.lr_scheduler is not None): + raise RuntimeError( + "Passing a `model_init` is incompatible with providing the `optimizers` argument. " + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + if is_torch_tpu_available() and self.optimizer is not None: + for param in self.model.parameters(): + model_device = param.device + break + for param_group in self.optimizer.param_groups: + if len(param_group["params"]) > 0: + optimizer_device = param_group["params"][0].device + break + if model_device != optimizer_device: + raise ValueError( + "The model and the optimizer parameters are not on the same device, which probably means you" + " created an optimizer around your model **before** putting on the device and passing it to the" + " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" + " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." + ) + if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and ( + self.optimizer is not None or self.lr_scheduler is not None + ): + raise RuntimeError( + "Passing `optimizers` is not allowed if Fairscale, Deepspeed or PyTorch FSDP is enabled." + "You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method." + ) + default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to) + callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks + self.callback_handler = CallbackHandler( + callbacks, self.model, self.tokenizer, self.optimizer, self.lr_scheduler + ) + self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK) + + # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. + self._loggers_initialized = False + + # Create clone of distant repo and output directory if needed + if self.args.push_to_hub: + self.init_git_repo(at_init=True) + # In case of pull, we need to make sure every process has the latest. + if is_torch_tpu_available(): + xm.rendezvous("init git repo") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + + if self.args.should_save: + os.makedirs(self.args.output_dir, exist_ok=True) + + if not callable(self.data_collator) and callable(getattr(self.data_collator, "collate_batch", None)): + raise ValueError("The `data_collator` should be a simple callable (function, class with `__call__`).") + + if args.max_steps > 0: + logger.info("max_steps is given, it will override any value given in num_train_epochs") + + if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: + raise ValueError( + "The train_dataset does not implement __len__, max_steps has to be specified. " + "The number of steps needs to be known in advance for the learning rate scheduler." + ) + + if ( + train_dataset is not None + and isinstance(train_dataset, torch.utils.data.IterableDataset) + and args.group_by_length + ): + raise ValueError("the `--group_by_length` option is only available for `Dataset`, not `IterableDataset") + + self._signature_columns = None + + # Mixed precision setup + self.use_apex = False + self.use_cuda_amp = False + self.use_cpu_amp = False + + # Mixed precision setup for SageMaker Model Parallel + if is_sagemaker_mp_enabled(): + # BF16 + model parallelism in SageMaker: currently not supported, raise an error + if args.bf16: + raise ValueError("SageMaker Model Parallelism does not support BF16 yet. Please use FP16 instead ") + + if IS_SAGEMAKER_MP_POST_1_10: + # When there's mismatch between SMP config and trainer argument, use SMP config as truth + if args.fp16 != smp.state.cfg.fp16: + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}," + f"but FP16 provided in trainer argument is {args.fp16}," + f"setting to {smp.state.cfg.fp16}" + ) + args.fp16 = smp.state.cfg.fp16 + else: + # smp < 1.10 does not support fp16 in trainer. + if hasattr(smp.state.cfg, "fp16"): + logger.warning( + f"FP16 provided in SM_HP_MP_PARAMETERS is {smp.state.cfg.fp16}, " + "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." + ) + + if (args.fp16 or args.bf16) and self.sharded_ddp is not None: + if args.half_precision_backend == "auto": + if args.device == torch.device("cpu"): + if args.fp16: + raise ValueError("Tried to use `fp16` but it is not supported on cpu") + elif _is_native_cpu_amp_available: + args.half_precision_backend = "cpu_amp" + else: + raise ValueError("Tried to use cpu amp but native cpu amp is not available") + else: + args.half_precision_backend = "cuda_amp" + + logger.info(f"Using {args.half_precision_backend} half precision backend") + + self.do_grad_scaling = False + if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): + # deepspeed and SageMaker Model Parallel manage their own half precision + if self.sharded_ddp is not None: + if args.half_precision_backend == "cuda_amp": + self.use_cuda_amp = True + self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 + # bf16 does not need grad scaling + self.do_grad_scaling = self.amp_dtype == torch.float16 + if self.do_grad_scaling: + if self.sharded_ddp is not None: + self.scaler = ShardedGradScaler() + elif self.fsdp is not None: + from torch.distributed.fsdp.sharded_grad_scaler import ( + ShardedGradScaler as FSDPShardedGradScaler, + ) + + self.scaler = FSDPShardedGradScaler() + elif is_torch_tpu_available(): + from torch_xla.amp import GradScaler + + self.scaler = GradScaler() + else: + self.scaler = torch.cuda.amp.GradScaler() + elif args.half_precision_backend == "cpu_amp": + self.use_cpu_amp = True + self.amp_dtype = torch.bfloat16 + elif args.half_precision_backend == "apex": + if not is_apex_available(): + raise ImportError( + "Using FP16 with APEX but APEX is not installed, please refer to" + " https://www.github.com/nvidia/apex." + ) + self.use_apex = True + + # FP16 + model parallelism in SageMaker: gradient clipping does not work for now so we raise a helpful error. + if ( + is_sagemaker_mp_enabled() + and self.use_cuda_amp + and args.max_grad_norm is not None + and args.max_grad_norm > 0 + ): + raise ValueError( + "SageMaker Model Parallelism in mixed precision mode does not support gradient clipping yet. Pass " + "along 'max_grad_norm': 0 in your hyperparameters." + ) + + # Label smoothing + if self.args.label_smoothing_factor != 0: + self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor) + else: + self.label_smoother = None + + self.state = TrainerState( + is_local_process_zero=self.is_local_process_zero(), + is_world_process_zero=self.is_world_process_zero(), + ) + + self.control = TrainerControl() + # Internal variable to count flos in each process, will be accumulated in `self.state.total_flos` then + # returned to 0 every time flos need to be logged + self.current_flos = 0 + self.hp_search_backend = None + self.use_tune_checkpoints = False + default_label_names = find_labels(self.model.__class__) + self.label_names = default_label_names if self.args.label_names is None else self.args.label_names + self.can_return_loss = can_return_loss(self.model.__class__) + self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) + + # Internal variables to keep track of the original batch size + self._train_batch_size = args.train_batch_size + + # very last + self._memory_tracker.stop_and_update_metrics() + + # torch.compile + if args.torch_compile and not is_torch_compile_available(): + raise RuntimeError("Using torch.compile requires PyTorch 2.0 or higher.") + + def add_callback(self, callback): + """ + Add a callback to the current list of [`~transformer.TrainerCallback`]. + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will instantiate a member of that class. + """ + self.callback_handler.add_callback(callback) + + def pop_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`] and returns it. + + If the callback is not found, returns `None` (and no error is raised). + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will pop the first member of that class found in the list of callbacks. + + Returns: + [`~transformer.TrainerCallback`]: The callback removed, if found. + """ + return self.callback_handler.pop_callback(callback) + + def remove_callback(self, callback): + """ + Remove a callback from the current list of [`~transformer.TrainerCallback`]. + + Args: + callback (`type` or [`~transformer.TrainerCallback`]): + A [`~transformer.TrainerCallback`] class or an instance of a [`~transformer.TrainerCallback`]. In the + first case, will remove the first member of that class found in the list of callbacks. + """ + self.callback_handler.remove_callback(callback) + + def _move_model_to_device(self, model, device): + model = model.to(device) + # Moving a model to an XLA device disconnects the tied weights, so we have to retie them. + if self.args.parallel_mode == ParallelMode.TPU and hasattr(model, "tie_weights"): + model.tie_weights() + + def _set_signature_columns_if_needed(self): + if self._signature_columns is None: + # Inspect model forward signature to keep only the arguments it accepts. + signature = inspect.signature(self.model.forward) + self._signature_columns = list(signature.parameters.keys()) + # Labels may be named label or label_ids, the default data collator handles that. + self._signature_columns += list(set(["label", "label_ids"] + self.label_names)) + + def _remove_unused_columns(self, dataset: "datasets.Dataset", description: Optional[str] = None): + if not self.args.remove_unused_columns: + return dataset + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + ignored_columns = list(set(dataset.column_names) - set(signature_columns)) + if len(ignored_columns) > 0: + dset_description = "" if description is None else f"in the {description} set" + logger.info( + f"The following columns {dset_description} don't have a corresponding argument in " + f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}." + f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, " + " you can safely ignore this message." + ) + + columns = [k for k in signature_columns if k in dataset.column_names] + + if version.parse(datasets.__version__) < version.parse("1.4.0"): + dataset.set_format( + type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"] + ) + return dataset + else: + return dataset.remove_columns(ignored_columns) + + def _get_collator_with_removed_columns( + self, data_collator: Callable, description: Optional[str] = None + ) -> Callable: + """Wrap the data collator in a callable removing unused columns.""" + if not self.args.remove_unused_columns: + return data_collator + self._set_signature_columns_if_needed() + signature_columns = self._signature_columns + + remove_columns_collator = RemoveColumnsCollator( + data_collator=data_collator, + signature_columns=signature_columns, + logger=logger, + description=description, + model_name=self.model.__class__.__name__, + ) + return remove_columns_collator + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.train_dataset is None or not has_length(self.train_dataset): + return None + + generator = None + if self.args.world_size <= 1: + generator = torch.Generator() + # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with + # `args.seed`) if data_seed isn't provided. + # Further on in this method, we default to `args.seed` instead. + if self.args.data_seed is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + else: + seed = self.args.data_seed + generator.manual_seed(seed) + + seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed + + # Build the sampler. + if self.args.group_by_length: + if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): + lengths = ( + self.train_dataset[self.args.length_column_name] + if self.args.length_column_name in self.train_dataset.column_names + else None + ) + else: + lengths = None + model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None + if self.args.world_size <= 1: + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + lengths=lengths, + model_input_name=model_input_name, + generator=generator, + ) + else: + return DistributedLengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + lengths=lengths, + model_input_name=model_input_name, + seed=seed, + ) + + else: + if self.args.world_size <= 1: + return RandomSampler(self.train_dataset, generator=generator) + elif ( + self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] + and not self.args.dataloader_drop_last + ): + # Use a loop for TPUs when drop_last is False to have all batches have the same size. + return DistributedSamplerWithLoop( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + num_replicas=self.args.world_size, + rank=self.args.process_index, + seed=seed, + ) + else: + return DistributedSampler( + self.train_dataset, + num_replicas=self.args.world_size, + rank=self.args.process_index, + seed=seed, + ) + + def get_train_dataloader(self) -> DataLoader: + """ + Returns the training [`~torch.utils.data.DataLoader`]. + + Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed + training if necessary) otherwise. + + Subclass and override this method if you want to inject some custom behavior. + """ + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + data_collator = self.data_collator + if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): + train_dataset = self._remove_unused_columns(train_dataset, description="training") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="training") + + if isinstance(train_dataset, torch.utils.data.IterableDataset): + if self.args.world_size > 1: + train_dataset = IterableDatasetShard( + train_dataset, + batch_size=self._train_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + + return DataLoader( + train_dataset, + batch_size=self._train_batch_size, + collate_fn=data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + train_sampler = self._get_train_sampler() + + return DataLoader( + train_dataset, + batch_size=self._train_batch_size, + sampler=train_sampler, + collate_fn=data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + worker_init_fn=seed_worker, + ) + + def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: + # Deprecated code + if self.args.use_legacy_prediction_loop: + if is_torch_tpu_available(): + return SequentialDistributedSampler( + eval_dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal() + ) + elif is_sagemaker_mp_enabled(): + return SequentialDistributedSampler( + eval_dataset, + num_replicas=smp.dp_size(), + rank=smp.dp_rank(), + batch_size=self.args.per_device_eval_batch_size, + ) + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + return SequentialDistributedSampler(eval_dataset) + else: + return SequentialSampler(eval_dataset) + + if self.args.world_size <= 1: + return SequentialSampler(eval_dataset) + else: + return ShardSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + """ + Returns the evaluation [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + eval_dataset (`torch.utils.data.Dataset`, *optional*): + If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted + by the `model.forward()` method are automatically removed. It must implement `__len__`. + """ + if eval_dataset is None and self.eval_dataset is None: + raise ValueError("Trainer: evaluation requires an eval_dataset.") + eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset + data_collator = self.data_collator + + if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset): + eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") + + if isinstance(eval_dataset, torch.utils.data.IterableDataset): + if self.args.world_size > 1: + eval_dataset = IterableDatasetShard( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + return DataLoader( + eval_dataset, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + eval_sampler = self._get_eval_sampler(eval_dataset) + + return DataLoader( + eval_dataset, + sampler=eval_sampler, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: + """ + Returns the test [`~torch.utils.data.DataLoader`]. + + Subclass and override this method if you want to inject some custom behavior. + + Args: + test_dataset (`torch.utils.data.Dataset`, *optional*): + The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the + `model.forward()` method are automatically removed. It must implement `__len__`. + """ + data_collator = self.data_collator + + if is_datasets_available() and isinstance(test_dataset, datasets.Dataset): + test_dataset = self._remove_unused_columns(test_dataset, description="test") + else: + data_collator = self._get_collator_with_removed_columns(data_collator, description="test") + + if isinstance(test_dataset, torch.utils.data.IterableDataset): + if self.args.world_size > 1: + test_dataset = IterableDatasetShard( + test_dataset, + batch_size=self.args.eval_batch_size, + drop_last=self.args.dataloader_drop_last, + num_processes=self.args.world_size, + process_index=self.args.process_index, + ) + return DataLoader( + test_dataset, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + test_sampler = self._get_eval_sampler(test_dataset) + + # We use the same batch_size as for eval. + return DataLoader( + test_dataset, + sampler=test_sampler, + batch_size=self.args.eval_batch_size, + collate_fn=data_collator, + drop_last=self.args.dataloader_drop_last, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + def create_optimizer_and_scheduler(self, num_training_steps: int): + """ + Setup the optimizer and the learning rate scheduler. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method (or `create_optimizer` and/or + `create_scheduler`) in a subclass. + """ + self.create_optimizer() + if IS_SAGEMAKER_MP_POST_1_10 and smp.state.cfg.fp16: + # If smp >= 1.10 and fp16 is enabled, we unwrap the optimizer + optimizer = self.optimizer.optimizer + else: + optimizer = self.optimizer + self.create_scheduler(num_training_steps=num_training_steps, optimizer=optimizer) + + def create_optimizer(self): + """ + Setup the optimizer. + + We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the + Trainer's init through `optimizers`, or subclass and override this method in a subclass. + """ + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + + if self.optimizer is None: + decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + optimizer_grouped_parameters = [ + { + "params": [ + p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) + ], + "weight_decay": self.args.weight_decay, + }, + { + "params": [ + p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) + + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer = OSS( + params=optimizer_grouped_parameters, + optim=optimizer_cls, + **optimizer_kwargs, + ) + else: + self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + if optimizer_cls.__name__ == "Adam8bit": + import bitsandbytes + + manager = bitsandbytes.optim.GlobalOptimManager.get_instance() + + skipped = 0 + for module in opt_model.modules(): + if isinstance(module, nn.Embedding): + skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) + logger.info(f"skipped {module}: {skipped/2**20}M params") + manager.register_module_override(module, "weight", {"optim_bits": 32}) + logger.debug(f"bitsandbytes: will optimize {module} in fp32") + logger.info(f"skipped: {skipped/2**20}M params") + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer(self.optimizer) + + return self.optimizer + + @staticmethod + def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]: + """ + Returns the optimizer class and optimizer parameters based on the training arguments. + + Args: + args (`transformers.training_args.TrainingArguments`): + The training arguments for the training session. + + """ + + # parse args.optim_args + optim_args = {} + if args.optim_args: + for mapping in args.optim_args.replace(" ", "").split(","): + key, value = mapping.split("=") + optim_args[key] = value + + optimizer_kwargs = {"lr": args.learning_rate} + + adam_kwargs = { + "betas": (args.adam_beta1, args.adam_beta2), + "eps": args.adam_epsilon, + } + if args.optim == OptimizerNames.ADAFACTOR: + optimizer_cls = Adafactor + optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) + elif args.optim == OptimizerNames.ADAMW_HF: + from transformers.optimization import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + elif args.optim in [OptimizerNames.ADAMW_TORCH, OptimizerNames.ADAMW_TORCH_FUSED]: + from torch.optim import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + if args.optim == OptimizerNames.ADAMW_TORCH_FUSED: + optimizer_kwargs.update({"fused": True}) + elif args.optim == OptimizerNames.ADAMW_TORCH_XLA: + try: + from torch_xla.amp.syncfree import AdamW + + optimizer_cls = AdamW + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer failed to import syncfree AdamW from torch_xla.") + elif args.optim == OptimizerNames.ADAMW_APEX_FUSED: + try: + from apex.optimizers import FusedAdam + + optimizer_cls = FusedAdam + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") + elif args.optim in [ + OptimizerNames.ADAMW_BNB, + OptimizerNames.ADAMW_8BIT, + OptimizerNames.PAGED_ADAMW, + OptimizerNames.PAGED_ADAMW_8BIT, + OptimizerNames.LION, + OptimizerNames.LION_8BIT, + OptimizerNames.PAGED_LION, + OptimizerNames.PAGED_LION_8BIT, + ]: + try: + from bitsandbytes.optim import AdamW, Lion + + is_paged = False + optim_bits = 32 + optimizer_cls = None + additional_optim_kwargs = adam_kwargs + if "paged" in args.optim: + is_paged = True + if "8bit" in args.optim: + optim_bits = 8 + if "adam" in args.optim: + optimizer_cls = AdamW + elif "lion" in args.optim: + optimizer_cls = Lion + additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)} + + bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits} + optimizer_kwargs.update(additional_optim_kwargs) + optimizer_kwargs.update(bnb_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!") + elif args.optim == OptimizerNames.ADAMW_BNB: + try: + from bitsandbytes.optim import Adam8bit + + optimizer_cls = Adam8bit + optimizer_kwargs.update(adam_kwargs) + except ImportError: + raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!") + elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: + try: + from torchdistx.optimizers import AnyPrecisionAdamW + + optimizer_cls = AnyPrecisionAdamW + optimizer_kwargs.update(adam_kwargs) + + # TODO Change dtypes back to M=FP32, Var = BF16, Kahan = False once they can be cast together in torchdistx. + optimizer_kwargs.update( + { + "use_kahan_summation": strtobool(optim_args.get("use_kahan_summation", "False")), + "momentum_dtype": getattr(torch, optim_args.get("momentum_dtype", "float32")), + "variance_dtype": getattr(torch, optim_args.get("variance_dtype", "float32")), + "compensation_buffer_dtype": getattr( + torch, optim_args.get("compensation_buffer_dtype", "bfloat16") + ), + } + ) + except ImportError: + raise ValueError("Please install https://github.com/pytorch/torchdistx") + elif args.optim == OptimizerNames.SGD: + optimizer_cls = torch.optim.SGD + elif args.optim == OptimizerNames.ADAGRAD: + optimizer_cls = torch.optim.Adagrad + else: + raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}") + return optimizer_cls, optimizer_kwargs + + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + """ + if self.lr_scheduler is None: + self.lr_scheduler = get_scheduler( + self.args.lr_scheduler_type, + optimizer=self.optimizer if optimizer is None else optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + return self.lr_scheduler + + def num_examples(self, dataloader: DataLoader) -> int: + """ + Helper to get number of samples in a [`~torch.utils.data.DataLoader`] by accessing its dataset. When + dataloader.dataset does not exist or has no length, estimates as best it can + """ + try: + dataset = dataloader.dataset + # Special case for IterableDatasetShard, we need to dig deeper + if isinstance(dataset, IterableDatasetShard): + return len(dataloader.dataset.dataset) + return len(dataloader.dataset) + except (NameError, AttributeError, TypeError): # no dataset or length, estimate by length of dataloader + return len(dataloader) * self.args.per_device_train_batch_size + + def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]): + """HP search setup code""" + self._trial = trial + + if self.hp_search_backend is None or trial is None: + return + if self.hp_search_backend == HPSearchBackend.OPTUNA: + params = self.hp_space(trial) + elif self.hp_search_backend == HPSearchBackend.RAY: + params = trial + params.pop("wandb", None) + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + params = {k: int(v) if isinstance(v, str) else v for k, v in trial.assignments.items()} + elif self.hp_search_backend == HPSearchBackend.WANDB: + params = trial + + for key, value in params.items(): + if not hasattr(self.args, key): + logger.warning( + f"Trying to set {key} in the hyperparameter search but there is no corresponding field in" + " `TrainingArguments`." + ) + continue + old_attr = getattr(self.args, key, None) + # Casting value to the proper type + if old_attr is not None: + value = type(old_attr)(value) + setattr(self.args, key, value) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + logger.info(f"Trial: {trial.params}") + if self.hp_search_backend == HPSearchBackend.SIGOPT: + logger.info(f"SigOpt Assignments: {trial.assignments}") + if self.hp_search_backend == HPSearchBackend.WANDB: + logger.info(f"W&B Sweep parameters: {trial}") + if self.is_deepspeed_enabled: + if self.args.deepspeed is None: + raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set") + # Rebuild the deepspeed config to reflect the updated training parameters + from accelerate.utils import DeepSpeedPlugin + + from transformers.deepspeed import HfTrainerDeepSpeedConfig + + self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed) + self.args.hf_deepspeed_config.trainer_config_process(self.args) + self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config) + self.create_accelerator_and_postprocess() + + def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]): + if self.hp_search_backend is None or trial is None: + return + self.objective = self.compute_objective(metrics.copy()) + if self.hp_search_backend == HPSearchBackend.OPTUNA: + import optuna + + trial.report(self.objective, step) + if trial.should_prune(): + self.callback_handler.on_train_end(self.args, self.state, self.control) + raise optuna.TrialPruned() + elif self.hp_search_backend == HPSearchBackend.RAY: + from ray import tune + + if self.control.should_save: + self._tune_save_checkpoint() + tune.report(objective=self.objective, **metrics) + + def _tune_save_checkpoint(self): + from ray import tune + + if not self.use_tune_checkpoints: + return + with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir: + output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}") + self.save_model(output_dir, _internal_call=True) + if self.args.should_save: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + + def call_model_init(self, trial=None): + model_init_argcount = number_of_arguments(self.model_init) + if model_init_argcount == 0: + model = self.model_init() + elif model_init_argcount == 1: + model = self.model_init(trial) + else: + raise RuntimeError("model_init should have 0 or 1 argument.") + + if model is None: + raise RuntimeError("model_init should not return None.") + + return model + + def torch_jit_model_eval(self, model, dataloader, training=False): + if not training: + if dataloader is None: + logger.warning("failed to use PyTorch jit mode due to current dataloader is none.") + return model + example_batch = next(iter(dataloader)) + example_batch = self._prepare_inputs(example_batch) + try: + jit_model = model.eval() + with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]): + if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"): + if isinstance(example_batch, dict): + jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False) + else: + jit_model = torch.jit.trace( + jit_model, + example_kwarg_inputs={key: example_batch[key] for key in example_batch}, + strict=False, + ) + else: + jit_inputs = [] + for key in example_batch: + example_tensor = torch.ones_like(example_batch[key]) + jit_inputs.append(example_tensor) + jit_inputs = tuple(jit_inputs) + jit_model = torch.jit.trace(jit_model, jit_inputs, strict=False) + jit_model = torch.jit.freeze(jit_model) + with torch.no_grad(): + jit_model(**example_batch) + jit_model(**example_batch) + model = jit_model + self.use_cpu_amp = False + self.use_cuda_amp = False + except (RuntimeError, TypeError, ValueError, NameError, IndexError) as e: + logger.warning(f"failed to use PyTorch jit mode due to: {e}.") + + return model + + def ipex_optimize_model(self, model, training=False, dtype=torch.float32): + if not is_ipex_available(): + raise ImportError( + "Using IPEX but IPEX is not installed or IPEX's version does not match current PyTorch, please refer" + " to https://github.com/intel/intel-extension-for-pytorch." + ) + + import intel_extension_for_pytorch as ipex + + if not training: + model.eval() + dtype = torch.bfloat16 if not self.is_in_train and self.args.bf16_full_eval else dtype + # conv_bn_folding is disabled as it fails in symbolic tracing, resulting in ipex warnings + model = ipex.optimize(model, dtype=dtype, level="O1", conv_bn_folding=False, inplace=not self.is_in_train) + else: + if not model.training: + model.train() + model, self.optimizer = ipex.optimize( + model, dtype=dtype, optimizer=self.optimizer, inplace=True, level="O1" + ) + + return model + + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.use_ipex: + dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 + model = self.ipex_optimize_model(model, training, dtype=dtype) + + if is_sagemaker_mp_enabled(): + # Wrapping the base model twice in a DistributedModel will raise an error. + if isinstance(self.model_wrapped, smp.model.DistributedModel): + return self.model_wrapped + return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) + + # train/eval could be run multiple-times - if already wrapped, don't re-wrap it again + if unwrap_model(model) is not model: + return model + + # Mixed precision training with apex (torch < 1.6) + if self.use_apex and training: + model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) + + # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP + if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): + model = nn.DataParallel(model) + + if self.args.jit_mode_eval: + start_time = time.time() + model = self.torch_jit_model_eval(model, dataloader, training) + self.jit_compilation_time = round(time.time() - start_time, 4) + + # Note: in torch.distributed mode, there's no point in wrapping the model + # inside a DistributedDataParallel as we'll be under `no_grad` anyways. + if not training: + return model + + # Distributed training (should be after apex fp16 initialization) + if self.sharded_ddp is not None: + # Sharded DDP! + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + model = ShardedDDP(model, self.optimizer) + else: + mixed_precision = self.args.fp16 or self.args.bf16 + cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp + zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 + # XXX: Breaking the self.model convention but I see no way around it for now. + if ShardedDDPOption.AUTO_WRAP in self.args.sharded_ddp: + model = auto_wrap(model) + self.model = model = FullyShardedDDP( + model, + mixed_precision=mixed_precision, + reshard_after_forward=zero_3, + cpu_offload=cpu_offload, + ).to(self.args.device) + # Distributed training using PyTorch FSDP + elif self.fsdp is not None and self.args.fsdp_config["xla"]: + try: + from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP + from torch_xla.distributed.fsdp import checkpoint_module + from torch_xla.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, + ) + except ImportError: + raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") + auto_wrap_policy = None + auto_wrapper_callable = None + if self.args.fsdp_config["fsdp_min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] + ) + elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + transformer_cls_to_wrap = set() + for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + fsdp_kwargs = self.args.xla_fsdp_config + if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + def auto_wrapper_callable(m, *args, **kwargs): + return FSDP(checkpoint_module(m), *args, **kwargs) + + # Wrap the base model with an outer FSDP wrapper + self.model = model = FSDP( + model, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + **fsdp_kwargs, + ) + + # Patch `xm.optimizer_step` should not reduce gradients in this case, + # as FSDP does not need gradient reduction over sharded parameters. + def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): + loss = optimizer.step(**optimizer_args) + if barrier: + xm.mark_step() + return loss + + xm.optimizer_step = patched_optimizer_step + elif is_sagemaker_dp_enabled(): + model = nn.parallel.DistributedDataParallel( + model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] + ) + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + if is_torch_neuroncore_available(): + return model + kwargs = {} + if self.args.ddp_find_unused_parameters is not None: + kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters + elif isinstance(model, PreTrainedModel): + # find_unused_parameters breaks checkpointing as per + # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021 + kwargs["find_unused_parameters"] = not model.is_gradient_checkpointing + else: + kwargs["find_unused_parameters"] = True + + if self.args.ddp_bucket_cap_mb is not None: + kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb + + self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) + + return model + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + trial: Union["optuna.Trial", Dict[str, Any]] = None, + ignore_keys_for_eval: Optional[List[str]] = None, + **kwargs, + ): + """ + Main training entry point. + + Args: + resume_from_checkpoint (`str` or `bool`, *optional*): + If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a + `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance + of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. + trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): + The trial run or the hyperparameter dictionary for hyperparameter search. + ignore_keys_for_eval (`List[str]`, *optional*) + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions for evaluation during the training. + kwargs: + Additional keyword arguments used to hide deprecated arguments + """ + if resume_from_checkpoint is False: + resume_from_checkpoint = None + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + args = self.args + + self.is_in_train = True + + # do_train is not a reliable argument, as it might not be set and .train() still called, so + # the following is a workaround: + if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: + self._move_model_to_device(self.model, args.device) + + if "model_path" in kwargs: + resume_from_checkpoint = kwargs.pop("model_path") + warnings.warn( + "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " + "instead.", + FutureWarning, + ) + if len(kwargs) > 0: + raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") + # This might change the seed so needs to run first. + self._hp_search_setup(trial) + self._train_batch_size = self.args.train_batch_size + + # Model re-init + model_reloaded = False + if self.model_init is not None: + # Seed must be set before instantiating the model when using model_init. + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.model = self.call_model_init(trial) + model_reloaded = True + # Reinitializes optimizer and scheduler + self.optimizer, self.lr_scheduler = None, None + + # Load potential model checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(args.output_dir) + if resume_from_checkpoint is None: + raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") + + if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled: + self._load_from_checkpoint(resume_from_checkpoint) + + # If model was re-initialized, put it on the right device and update self.model_wrapped + if model_reloaded: + if self.place_model_on_device: + self._move_model_to_device(self.model, args.device) + self.model_wrapped = self.model + + inner_training_loop = find_executable_batch_size( + self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size + ) + + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + + def _inner_training_loop( + self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None + ): + self.accelerator.free_memory() + self._train_batch_size = batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + + len_dataloader = None + if has_length(train_dataloader): + len_dataloader = len(train_dataloader) + num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + num_examples = self.num_examples(train_dataloader) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs + elif args.max_steps > 0: # Rely on max_steps when dataloader does not have a working size + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. + num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_examples = total_train_batch_size * args.max_steps + num_train_samples = args.max_steps * total_train_batch_size + else: + raise ValueError( + "args.max_steps must be set to a positive value if dataloader does not have a length, was" + f" {args.max_steps}" + ) + + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps and args.logging_steps < 1: + args.logging_steps = math.ceil(max_steps * args.logging_steps) + if args.eval_steps and args.eval_steps < 1: + args.eval_steps = math.ceil(max_steps * args.eval_steps) + if args.save_steps and args.save_steps < 1: + args.save_steps = math.ceil(max_steps * args.save_steps) + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP" + " (torch.distributed.launch)." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = ( + self.sharded_ddp is not None + and self.sharded_ddp != ShardedDDPOption.SIMPLE + or is_sagemaker_mp_enabled() + or self.fsdp is not None + ) + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState() + self.state.is_hyper_param_search = trial is not None + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + model = self._wrap_model(self.model_wrapped) + + if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: + self._load_from_checkpoint(resume_from_checkpoint, model) + + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # deepspeed ckpt loading + if resume_from_checkpoint is not None and self.is_deepspeed_enabled: + deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + epochs_trained = self.state.global_step // num_update_steps_per_epoch + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + if skip_first_batches is None: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time," + " you can install the latest version of Accelerate with `pip install -U accelerate`.You can" + " also add the `--ignore_data_skip` flag to your launch command, but you will resume the" + " training on data already seen by your model." + ) + else: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." + ) + if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None: + steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) + steps_trained_progress_bar.set_description("Skipping the first batches") + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + if self.hp_name is not None and self._trial is not None: + # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial + # parameter to Train when using DDP. + self.state.trial_name = self.hp_name(self._trial) + if trial is not None: + assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. + if not args.ignore_data_skip: + for epoch in range(epochs_trained): + is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( + train_dataloader.sampler, RandomSampler + ) + if is_torch_less_than_1_11 or not is_random_sampler: + # We just need to begin an iteration to create the randomization of the sampler. + # That was before PyTorch 1.11 however... + for _ in train_dataloader: + break + else: + # Otherwise we need to call the whooooole sampler cause there is some random operation added + # AT THE VERY END! + _ = list(train_dataloader.sampler) + + total_batched_samples = 0 + for epoch in range(epochs_trained, num_train_epochs): + if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): + train_dataloader.sampler.set_epoch(epoch) + elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): + train_dataloader.dataset.set_epoch(epoch) + + if is_torch_tpu_available(): + parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) + epoch_iterator = parallel_loader + else: + epoch_iterator = train_dataloader + + # Reset the past mems state at the beginning of each epoch if necessary. + if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_iterator) + if len_dataloader is not None + else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + + rng_to_sync = False + steps_skipped = 0 + if skip_first_batches is not None and steps_trained_in_current_epoch > 0: + epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) + steps_skipped = steps_trained_in_current_epoch + steps_trained_in_current_epoch = 0 + rng_to_sync = True + + step = -1 + for step, inputs in enumerate(epoch_iterator): + total_batched_samples += 1 + if rng_to_sync: + self._load_rng_state(resume_from_checkpoint) + rng_to_sync = False + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + with self.accelerator.accumulate(model): + tr_loss_step = self.training_step(model, inputs) + + if ( + args.logging_nan_inf_filter + and not is_torch_tpu_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + tr_loss += tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + # should this be under the accumulate context manager? + # the `or` condition of `steps_in_epoch <= args.gradient_accumulation_steps` is not covered + # in accelerate + if total_batched_samples % args.gradient_accumulation_steps == 0 or ( + # last step in epoch but step is always smaller than gradient_accumulation_steps + steps_in_epoch <= args.gradient_accumulation_steps + and (step + 1) == steps_in_epoch + ): + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0: + # deepspeed does its own clipping + + if self.do_grad_scaling: + # Reduce gradients first for XLA + if is_torch_tpu_available(): + gradients = xm._fetch_gradients(self.optimizer) + xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) + # AMP: gradients need unscaling + self.scaler.unscale_(self.optimizer) + + if is_sagemaker_mp_enabled() and args.fp16: + self.optimizer.clip_master_grads(args.max_grad_norm) + elif hasattr(self.optimizer, "clip_grad_norm"): + # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping + self.optimizer.clip_grad_norm(args.max_grad_norm) + elif hasattr(model, "clip_grad_norm_"): + # Some models (like FullyShardedDDP) have a specific way to do gradient clipping + model.clip_grad_norm_(args.max_grad_norm) + elif self.use_apex: + # Revert to normal clipping otherwise, handling Apex or full precision + nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + self.accelerator.clip_grad_norm_( + model.parameters(), + args.max_grad_norm, + ) + + # Optimizer step + optimizer_was_run = True + if is_torch_tpu_available(): + if self.do_grad_scaling: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + xm.optimizer_step(self.optimizer) + elif self.do_grad_scaling: + scale_before = self.scaler.get_scale() + self.scaler.step(self.optimizer) + self.scaler.update() + scale_after = self.scaler.get_scale() + optimizer_was_run = scale_before <= scale_after + else: + self.optimizer.step() + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped + + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + + self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + if self.control.should_epoch_stop or self.control.should_training_stop: + break + if step < 0: + logger.warning( + "There seems to be not a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_tpu_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sur the model has been saved by process 0. + if is_torch_tpu_available(): + xm.rendezvous("load_best_model_at_end") + elif args.parallel_mode == ParallelMode.DISTRIBUTED: + dist.barrier() + elif is_sagemaker_mp_enabled(): + smp.barrier() + + self._load_best_model() + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + train_loss = self._total_loss_scalar / self.state.global_step + + metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + run_dir = self._get_output_dir(trial) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir) + + # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. + if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: + for checkpoint in checkpoints_sorted: + if checkpoint != self.state.best_model_checkpoint: + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + return TrainOutput(self.state.global_step, train_loss, metrics) + + def _get_output_dir(self, trial): + if self.hp_search_backend is not None and trial is not None: + if self.hp_search_backend == HPSearchBackend.OPTUNA: + run_id = trial.number + elif self.hp_search_backend == HPSearchBackend.RAY: + from ray import tune + + run_id = tune.get_trial_id() + elif self.hp_search_backend == HPSearchBackend.SIGOPT: + run_id = trial.id + elif self.hp_search_backend == HPSearchBackend.WANDB: + import wandb + + run_id = wandb.run.id + run_name = self.hp_name(trial) if self.hp_name is not None else f"run-{run_id}" + run_dir = os.path.join(self.args.output_dir, run_name) + else: + run_dir = self.args.output_dir + return run_dir + + def _load_from_checkpoint(self, resume_from_checkpoint, model=None): + if model is None: + model = self.model + + config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME) + + weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME) + weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME) + safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME) + safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME) + + if not any( + os.path.isfile(f) for f in [weights_file, safe_weights_file, weights_index_file, safe_weights_index_file] + ): + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") + + logger.info(f"Loading model from {resume_from_checkpoint}.") + + if os.path.isfile(config_file): + config = PretrainedConfig.from_json_file(config_file) + checkpoint_version = config.transformers_version + if checkpoint_version is not None and checkpoint_version != __version__: + logger.warning( + f"You are resuming training from a checkpoint trained with {checkpoint_version} of " + f"Transformers but your current version is {__version__}. This is not recommended and could " + "yield to errors or unwanted behaviors." + ) + + if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file): + # If the model is on the GPU, it still works! + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=resume_from_checkpoint, tag=WEIGHTS_NAME, partial=False, load_optimizer=False + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + if hasattr(self.args, "fp16") and self.args.fp16 is True: + logger.warning( + "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported." + ) + state_dict = torch.load(weights_file, map_location="cpu") + # Required for smp to not auto-translate state_dict from hf to smp (is already smp). + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + # release memory + del state_dict + elif self.is_fsdp_enabled: + self.accelerator.state.fsdp_plugin.load_model(self.accelerator, model, resume_from_checkpoint) + else: + # We load the model state dict on the CPU to avoid an OOM error. + if self.args.save_safetensors and os.path.isfile(safe_weights_file): + state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu") + else: + state_dict = torch.load(weights_file, map_location="cpu") + + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + # release memory + del state_dict + self._issue_warnings_after_load(load_result) + else: + # We load the sharded checkpoint + load_result = load_sharded_checkpoint( + model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors + ) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + + def _load_best_model(self): + logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") + best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) + best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) + best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME) + best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) + + model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if ( + os.path.exists(best_model_path) + or os.path.exists(best_safe_model_path) + or os.path.exists(best_adapter_model_path) + or os.path.exists(best_safe_adapter_model_path) + ): + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint) + else: + has_been_loaded = True + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=self.state.best_model_checkpoint, + tag=WEIGHTS_NAME, + partial=False, + load_optimizer=False, + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. + if self.args.save_safetensors and os.path.isfile(best_safe_model_path): + state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") + else: + state_dict = torch.load(best_model_path, map_location="cpu") + + state_dict["_smp_is_partial"] = False + load_result = model.load_state_dict(state_dict, strict=True) + elif self.is_fsdp_enabled: + self.accelerator.state.fsdp_plugin.load_model( + self.accelerator, model, self.state.best_model_checkpoint + ) + else: + if is_peft_available() and isinstance(model, PeftModel): + # If train a model using PEFT & LoRA, assume that adapter have been saved properly. + if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): + if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path): + model.load_adapter(self.state.best_model_checkpoint, model.active_adapter) + # Load_adapter has no return value present, modify it when appropriate. + from torch.nn.modules.module import _IncompatibleKeys + + load_result = _IncompatibleKeys([], []) + else: + logger.warning( + "The intermediate checkpoints of PEFT may not be saved correctly, " + f"using `TrainerCallback` to save {ADAPTER_WEIGHTS_NAME} in corresponding folders, " + "here are some examples https://github.com/huggingface/peft/issues/96" + ) + has_been_loaded = False + else: + logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") + has_been_loaded = False + else: + # We load the model state dict on the CPU to avoid an OOM error. + if self.args.save_safetensors and os.path.isfile(best_safe_model_path): + state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu") + else: + state_dict = torch.load(best_model_path, map_location="cpu") + + # If the model is on the GPU, it still works! + # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 + # which takes *args instead of **kwargs + load_result = model.load_state_dict(state_dict, False) + if not is_sagemaker_mp_enabled() and has_been_loaded: + self._issue_warnings_after_load(load_result) + elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): + load_result = load_sharded_checkpoint( + model, self.state.best_model_checkpoint, strict=is_sagemaker_mp_enabled() + ) + if not is_sagemaker_mp_enabled(): + self._issue_warnings_after_load(load_result) + else: + logger.warning( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) + + def _issue_warnings_after_load(self, load_result): + if len(load_result.missing_keys) != 0: + if self.model._keys_to_ignore_on_save is not None and set(load_result.missing_keys) == set( + self.model._keys_to_ignore_on_save + ): + self.model.tie_weights() + else: + logger.warning(f"There were missing keys in the checkpoint model loaded: {load_result.missing_keys}.") + if len(load_result.unexpected_keys) != 0: + logger.warning( + f"There were unexpected keys in the checkpoint model loaded: {load_result.unexpected_keys}." + ) + + def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): + if self.control.should_log: + if is_torch_tpu_available(): + xm.mark_step() + + logs: Dict[str, float] = {} + + # all_gather + mean() to get average loss over all processes + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs) + + metrics = None + if self.control.should_evaluate: + if isinstance(self.eval_dataset, dict): + metrics = {} + for eval_dataset_name, eval_dataset in self.eval_dataset.items(): + dataset_metrics = self.evaluate( + eval_dataset=eval_dataset, + ignore_keys=ignore_keys_for_eval, + metric_key_prefix=f"eval_{eval_dataset_name}", + ) + metrics.update(dataset_metrics) + else: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, self.state.global_step, metrics) + + # Run delayed LR scheduler now that metrics are populated + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + self.lr_scheduler.step(metrics[metric_to_check]) + + if self.control.should_save: + self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def _load_rng_state(self, checkpoint): + # Load RNG states from `checkpoint` + if checkpoint is None: + return + + if self.args.world_size > 1: + process_index = self.args.process_index + rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth") + if not os.path.isfile(rng_file): + logger.info( + f"Didn't find an RNG file for process {process_index}, if you are resuming a training that " + "wasn't launched in a distributed fashion, reproducibility is not guaranteed." + ) + return + else: + rng_file = os.path.join(checkpoint, "rng_state.pth") + if not os.path.isfile(rng_file): + logger.info( + "Didn't find an RNG file, if you are resuming a training that was launched in a distributed " + "fashion, reproducibility is not guaranteed." + ) + return + + checkpoint_rng_state = torch.load(rng_file) + random.setstate(checkpoint_rng_state["python"]) + np.random.set_state(checkpoint_rng_state["numpy"]) + torch.random.set_rng_state(checkpoint_rng_state["cpu"]) + if torch.cuda.is_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"]) + else: + try: + torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"]) + except Exception as e: + logger.info( + f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" + "\nThis won't yield the same results as if the training had not been interrupted." + ) + if is_torch_tpu_available(): + xm.set_rng_state(checkpoint_rng_state["xla"]) + + def _save_checkpoint(self, model, trial, metrics=None): + # In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we + # want to save except FullyShardedDDP. + # assert unwrap_model(model) is self.model, "internal model should be a reference to self.model" + + # Save model checkpoint + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}" # changed by homeway, 20230711 + + if self.hp_search_backend is None and trial is None: + self.store_flos() + + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + self.save_model(output_dir, _internal_call=True) + if self.is_deepspeed_enabled: + # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed + # config `stage3_gather_16bit_weights_on_model_save` is True + self.model_wrapped.save_checkpoint(output_dir) + + # Save optimizer and scheduler + if self.sharded_ddp == ShardedDDPOption.SIMPLE: + self.optimizer.consolidate_state_dict() + + if self.fsdp: + # FSDP has a different interface for saving optimizer states. + # Needs to be called on all ranks to gather all states. + # full_optim_state_dict will be deprecated after Pytorch 2.2! + full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer) + + if is_torch_tpu_available(): + xm.rendezvous("saving_optimizer_states") + xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: + xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + elif is_sagemaker_mp_enabled(): + opt_state_dict = self.optimizer.local_state_dict(gather_if_shard=False) + smp.barrier() + if smp.rdp_rank() == 0 or smp.state.cfg.shard_optimizer_state: + smp.save( + opt_state_dict, + os.path.join(output_dir, OPTIMIZER_NAME), + partial=True, + v3=smp.state.cfg.shard_optimizer_state, + ) + if self.args.should_save: + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + elif self.args.should_save and not self.is_deepspeed_enabled: + # deepspeed.save_checkpoint above saves model/optim/sched + if self.fsdp: + torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME)) + else: + torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + + with warnings.catch_warnings(record=True) as caught_warnings: + torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling: + torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) + + # Determine the new best metric / best model checkpoint + if metrics is not None and self.args.metric_for_best_model is not None: + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + metric_value = metrics[metric_to_check] + + operator = np.greater if self.args.greater_is_better else np.less + if ( + self.state.best_metric is None + or self.state.best_model_checkpoint is None + or operator(metric_value, self.state.best_metric) + ): + self.state.best_metric = metric_value + self.state.best_model_checkpoint = output_dir + + # Save the Trainer state + if self.args.should_save: + self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) + + # Save RNG state in non-distributed training + rng_states = { + "python": random.getstate(), + "numpy": np.random.get_state(), + "cpu": torch.random.get_rng_state(), + } + if torch.cuda.is_available(): + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) + rng_states["cuda"] = torch.cuda.random.get_rng_state_all() + else: + rng_states["cuda"] = torch.cuda.random.get_rng_state() + + if is_torch_tpu_available(): + rng_states["xla"] = xm.get_rng_state() + + # A process can arrive here before the process 0 has a chance to save the model, in which case output_dir may + # not yet exist. + os.makedirs(output_dir, exist_ok=True) + + if self.args.world_size <= 1: + torch.save(rng_states, os.path.join(output_dir, "rng_state.pth")) + else: + torch.save(rng_states, os.path.join(output_dir, f"rng_state_{self.args.process_index}.pth")) + + if self.args.push_to_hub: + self._push_from_checkpoint(output_dir) + + # Maybe delete some older checkpoints. + if self.args.should_save: + self._rotate_checkpoints(use_mtime=True, output_dir=run_dir) + + def _load_optimizer_and_scheduler(self, checkpoint): + """If optimizer and scheduler states exist, load them.""" + if checkpoint is None: + return + + if self.is_deepspeed_enabled: + # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init + return + + checkpoint_file_exists = ( + glob.glob(os.path.join(checkpoint, OPTIMIZER_NAME) + "_*") + if is_sagemaker_mp_enabled() + else os.path.isfile(os.path.join(checkpoint, OPTIMIZER_NAME)) + ) + if checkpoint_file_exists and os.path.isfile(os.path.join(checkpoint, SCHEDULER_NAME)): + # Load in optimizer and scheduler states + if is_torch_tpu_available(): + # On TPU we have to take some extra precautions to properly load the states on the right device. + optimizer_state = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location="cpu") + with warnings.catch_warnings(record=True) as caught_warnings: + lr_scheduler_state = torch.load(os.path.join(checkpoint, SCHEDULER_NAME), map_location="cpu") + reissue_pt_warnings(caught_warnings) + + xm.send_cpu_data_to_device(optimizer_state, self.args.device) + xm.send_cpu_data_to_device(lr_scheduler_state, self.args.device) + + self.optimizer.load_state_dict(optimizer_state) + self.lr_scheduler.load_state_dict(lr_scheduler_state) + else: + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): + # Optimizer checkpoint was saved with smp >= 1.10 + def opt_load_hook(mod, opt): + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + else: + # Optimizer checkpoint was saved with smp < 1.10 + def opt_load_hook(mod, opt): + if IS_SAGEMAKER_MP_POST_1_10: + opt.load_state_dict( + smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True, back_compat=True) + ) + else: + opt.load_state_dict(smp.load(os.path.join(checkpoint, OPTIMIZER_NAME), partial=True)) + + self.model_wrapped.register_post_step_hook(opt_load_hook) + else: + # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models. + # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more + # likely to get OOM on CPU (since we load num_gpu times the optimizer state + map_location = self.args.device if self.args.world_size > 1 else "cpu" + if self.fsdp: + full_osd = None + # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it + if self.args.process_index == 0: + full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME)) + # call scatter_full_optim_state_dict on all ranks + sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model) + self.optimizer.load_state_dict(sharded_osd) + else: + self.optimizer.load_state_dict( + torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) + ) + with warnings.catch_warnings(record=True) as caught_warnings: + self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) + reissue_pt_warnings(caught_warnings) + if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): + self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME))) + + def hyperparameter_search( + self, + hp_space: Optional[Callable[["optuna.Trial"], Dict[str, float]]] = None, + compute_objective: Optional[Callable[[Dict[str, float]], float]] = None, + n_trials: int = 20, + direction: str = "minimize", + backend: Optional[Union["str", HPSearchBackend]] = None, + hp_name: Optional[Callable[["optuna.Trial"], str]] = None, + **kwargs, + ) -> BestRun: + """ + Launch an hyperparameter search using `optuna` or `Ray Tune` or `SigOpt`. The optimized quantity is determined + by `compute_objective`, which defaults to a function returning the evaluation loss when no metric is provided, + the sum of all metrics otherwise. + + + + To use this method, you need to have provided a `model_init` when initializing your [`Trainer`]: we need to + reinitialize the model at each new run. This is incompatible with the `optimizers` argument, so you need to + subclass [`Trainer`] and override the method [`~Trainer.create_optimizer_and_scheduler`] for custom + optimizer/scheduler. + + + + Args: + hp_space (`Callable[["optuna.Trial"], Dict[str, float]]`, *optional*): + A function that defines the hyperparameter search space. Will default to + [`~trainer_utils.default_hp_space_optuna`] or [`~trainer_utils.default_hp_space_ray`] or + [`~trainer_utils.default_hp_space_sigopt`] depending on your backend. + compute_objective (`Callable[[Dict[str, float]], float]`, *optional*): + A function computing the objective to minimize or maximize from the metrics returned by the `evaluate` + method. Will default to [`~trainer_utils.default_compute_objective`]. + n_trials (`int`, *optional*, defaults to 100): + The number of trial runs to test. + direction (`str`, *optional*, defaults to `"minimize"`): + Whether to optimize greater or lower objects. Can be `"minimize"` or `"maximize"`, you should pick + `"minimize"` when optimizing the validation loss, `"maximize"` when optimizing one or several metrics. + backend (`str` or [`~training_utils.HPSearchBackend`], *optional*): + The backend to use for hyperparameter search. Will default to optuna or Ray Tune or SigOpt, depending + on which one is installed. If all are installed, will default to optuna. + hp_name (`Callable[["optuna.Trial"], str]]`, *optional*): + A function that defines the trial/run name. Will default to None. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to `optuna.create_study` or `ray.tune.run`. For more + information see: + + - the documentation of + [optuna.create_study](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.create_study.html) + - the documentation of [tune.run](https://docs.ray.io/en/latest/tune/api_docs/execution.html#tune-run) + - the documentation of [sigopt](https://app.sigopt.com/docs/endpoints/experiments/create) + + Returns: + [`trainer_utils.BestRun`]: All the information about the best run. Experiment summary can be found in + `run_summary` attribute for Ray backend. + """ + if backend is None: + backend = default_hp_search_backend() + if backend is None: + raise RuntimeError( + "At least one of optuna or ray should be installed. " + "To install optuna run `pip install optuna`. " + "To install ray run `pip install ray[tune]`. " + "To install sigopt run `pip install sigopt`." + ) + backend = HPSearchBackend(backend) + if backend == HPSearchBackend.OPTUNA and not is_optuna_available(): + raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.") + if backend == HPSearchBackend.RAY and not is_ray_tune_available(): + raise RuntimeError( + "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`." + ) + if backend == HPSearchBackend.SIGOPT and not is_sigopt_available(): + raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.") + if backend == HPSearchBackend.WANDB and not is_wandb_available(): + raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.") + self.hp_search_backend = backend + if self.model_init is None: + raise RuntimeError( + "To use hyperparameter search, you need to pass your model through a model_init function." + ) + + self.hp_space = default_hp_space[backend] if hp_space is None else hp_space + self.hp_name = hp_name + self.compute_objective = default_compute_objective if compute_objective is None else compute_objective + + backend_dict = { + HPSearchBackend.OPTUNA: run_hp_search_optuna, + HPSearchBackend.RAY: run_hp_search_ray, + HPSearchBackend.SIGOPT: run_hp_search_sigopt, + HPSearchBackend.WANDB: run_hp_search_wandb, + } + best_run = backend_dict[backend](self, n_trials, direction, **kwargs) + + self.hp_search_backend = None + return best_run + + def log(self, logs: Dict[str, float]) -> None: + """ + Log `logs` on the various objects watching training. + + Subclass and override this method to inject custom behavior. + + Args: + logs (`Dict[str, float]`): + The values to log. + """ + if self.state.epoch is not None: + logs["epoch"] = round(self.state.epoch, 2) + + output = {**logs, **{"step": self.state.global_step}} + self.state.log_history.append(output) + self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) + + def _prepare_input(self, data: Union[torch.Tensor, Any]) -> Union[torch.Tensor, Any]: + """ + Prepares one `data` before feeding it to the model, be it a tensor or a nested list/dictionary of tensors. + """ + if isinstance(data, Mapping): + return type(data)({k: self._prepare_input(v) for k, v in data.items()}) + elif isinstance(data, (tuple, list)): + return type(data)(self._prepare_input(v) for v in data) + elif isinstance(data, torch.Tensor): + kwargs = {"device": self.args.device} + if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): + # NLP models inputs are int/uint and those get adjusted to the right dtype of the + # embedding. Other models such as wav2vec2's inputs are already float and thus + # may need special handling to match the dtypes of the model + kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) + return data.to(**kwargs) + return data + + def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Prepare `inputs` before feeding them to the model, converting them to tensors if they are not already and + handling potential state. + """ + inputs = self._prepare_input(inputs) + if len(inputs) == 0: + raise ValueError( + "The batch received was empty, your model won't be able to train on it. Double-check that your " + f"training dataset contains keys expected by the model: {','.join(self._signature_columns)}." + ) + if self.args.past_index >= 0 and self._past is not None: + inputs["mems"] = self._past + + return inputs + + def compute_loss_context_manager(self): + """ + A helper wrapper to group together context managers. + """ + return self.autocast_smart_context_manager() + + def autocast_smart_context_manager(self, cache_enabled: Optional[bool] = True): + """ + A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired + arguments, depending on the situation. + """ + if self.use_cuda_amp or self.use_cpu_amp: + if is_torch_greater_or_equal_than_1_10: + ctx_manager = ( + torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) + if self.use_cpu_amp + else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) + ) + else: + ctx_manager = torch.cuda.amp.autocast() + else: + ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() + + return ctx_manager + + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + """ + Perform a training step on a batch of inputs. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to train. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + + Return: + `torch.Tensor`: The tensor with training loss on this batch. + """ + model.train() + inputs = self._prepare_inputs(inputs) + + if is_sagemaker_mp_enabled(): + loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps) + return loss_mb.reduce_mean().detach().to(self.args.device) + + with self.compute_loss_context_manager(): + loss = self.compute_loss(model, inputs) + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.do_grad_scaling: + self.scaler.scale(loss).backward() + elif self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss) + + return loss.detach() / self.args.gradient_accumulation_steps + + def compute_loss(self, model, inputs, return_outputs=False): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. + """ + if self.label_smoother is not None and "labels" in inputs: + labels = inputs.pop("labels") + else: + labels = None + outputs = model(**inputs) + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index] + + if labels is not None: + if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + loss = self.label_smoother(outputs, labels, shift_labels=True) + else: + loss = self.label_smoother(outputs, labels) + else: + if isinstance(outputs, dict) and "loss" not in outputs: + raise ValueError( + "The model did not return a loss from the inputs, only the following keys: " + f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." + ) + # We don't use .loss here since the model may return tuples instead of ModelOutput. + loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0] + + return (loss, outputs) if return_outputs else loss + + def is_local_process_zero(self) -> bool: + """ + Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several + machines) main process. + """ + return self.args.local_process_index == 0 + + def is_world_process_zero(self) -> bool: + """ + Whether or not this process is the global main process (when training in a distributed fashion on several + machines, this is only going to be `True` for one process). + """ + # Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global + # process index. + if is_sagemaker_mp_enabled(): + return smp.rank() == 0 + else: + return self.args.process_index == 0 + + def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False): + """ + Will save the model, so you can reload it using `from_pretrained()`. + + Will only save from the main process. + """ + + if output_dir is None: + output_dir = self.args.output_dir + + if is_torch_tpu_available(): + self._save_tpu(output_dir) + elif is_sagemaker_mp_enabled(): + # Calling the state_dict needs to be done on the wrapped model and on all processes. + os.makedirs(output_dir, exist_ok=True) + state_dict = self.model_wrapped.state_dict() + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + if IS_SAGEMAKER_MP_POST_1_10: + # 'user_content.pt' indicates model state_dict saved with smp >= 1.10 + Path(os.path.join(output_dir, "user_content.pt")).touch() + elif ( + ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp + or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp + or self.fsdp is not None + or self.is_fsdp_enabled + ): + if self.is_fsdp_enabled: + os.makedirs(output_dir, exist_ok=True) + self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir) + else: + state_dict = self.model.state_dict() + + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + elif self.is_deepspeed_enabled: + # this takes care of everything as long as we aren't under zero3 + if self.args.should_save: + self._save(output_dir) + + if is_deepspeed_zero3_enabled(): + # It's too complicated to try to override different places where the weights dump gets + # saved, so since under zero3 the file is bogus, simply delete it. The user should + # either user deepspeed checkpoint to resume or to recover full weights use + # zero_to_fp32.py stored in the checkpoint. + if self.args.should_save: + file = os.path.join(output_dir, WEIGHTS_NAME) + if os.path.isfile(file): + # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights") + os.remove(file) + + # now save the real model if stage3_gather_16bit_weights_on_model_save=True + # if false it will not be saved. + # This must be called on all ranks + if not self.model_wrapped.save_16bit_model(output_dir, WEIGHTS_NAME): + logger.warning( + "deepspeed.save_16bit_model didn't save the model, since" + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" + " zero_to_fp32.py to recover weights" + ) + self.model_wrapped.save_checkpoint(output_dir) + + elif self.args.should_save: + self._save(output_dir) + + # Push to the Hub when `save_model` is called by the user. + if self.args.push_to_hub and not _internal_call: + self.push_to_hub(commit_message="Model save") + + def _save_tpu(self, output_dir: Optional[str] = None): + output_dir = output_dir if output_dir is not None else self.args.output_dir + logger.info(f"Saving model checkpoint to {output_dir}") + + if xm.is_master_ordinal(): + os.makedirs(output_dir, exist_ok=True) + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + xm.rendezvous("saving_checkpoint") + if not isinstance(self.model, PreTrainedModel): + if isinstance(unwrap_model(self.model), PreTrainedModel): + unwrap_model(self.model).save_pretrained( + output_dir, + is_main_process=self.args.should_save, + state_dict=self.model.state_dict(), + save_function=xm.save, + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + state_dict = self.model.state_dict() + xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save) + if self.tokenizer is not None and self.args.should_save: + self.tokenizer.save_pretrained(output_dir) + + def _save(self, output_dir: Optional[str] = None, state_dict=None): + # If we are executing this function, we are the process zero, so we don't check for that. + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving model checkpoint to {output_dir}") + + supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) + # Save a trained model and configuration using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + if not isinstance(self.model, supported_classes): + if state_dict is None: + state_dict = self.model.state_dict() + + if isinstance(unwrap_model(self.model), supported_classes): + unwrap_model(self.model).save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + else: + logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") + if self.args.save_safetensors: + safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME)) + else: + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + else: + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + def store_flos(self): + # Storing the number of floating-point operations that went into the model + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + self.state.total_flos += ( + distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item() + ) + self.current_flos = 0 + else: + self.state.total_flos += self.current_flos + self.current_flos = 0 + + def _sorted_checkpoints( + self, output_dir=None, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False + ) -> List[str]: + ordering_and_checkpoint_path = [] + + glob_checkpoints = [str(x) for x in Path(output_dir).glob(f"{checkpoint_prefix}-*") if os.path.isdir(x)] + + for path in glob_checkpoints: + if use_mtime: + ordering_and_checkpoint_path.append((os.path.getmtime(path), path)) + else: + regex_match = re.match(f".*{checkpoint_prefix}-([0-9]+)", path) + if regex_match is not None and regex_match.groups() is not None: + ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path)) + + checkpoints_sorted = sorted(ordering_and_checkpoint_path) + checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted] + # Make sure we don't delete the best model. + if self.state.best_model_checkpoint is not None: + best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint))) + for i in range(best_model_index, len(checkpoints_sorted) - 2): + checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i] + return checkpoints_sorted + + def _rotate_checkpoints(self, use_mtime=False, output_dir=None) -> None: + if self.args.save_total_limit is None or self.args.save_total_limit <= 0: + return + + # Check if we should delete older checkpoint(s) + checkpoints_sorted = self._sorted_checkpoints(use_mtime=use_mtime, output_dir=output_dir) + if len(checkpoints_sorted) <= self.args.save_total_limit: + return + + # If save_total_limit=1 with load_best_model_at_end=True, we could end up deleting the last checkpoint, which + # we don't do to allow resuming. + save_total_limit = self.args.save_total_limit + if ( + self.state.best_model_checkpoint is not None + and self.args.save_total_limit == 1 + and checkpoints_sorted[-1] != self.state.best_model_checkpoint + ): + save_total_limit = 2 + + number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - save_total_limit) + checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete] + for checkpoint in checkpoints_to_be_deleted: + logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") + shutil.rmtree(checkpoint, ignore_errors=True) + + def evaluate( + self, + eval_dataset: Optional[Dataset] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> Dict[str, float]: + """ + Run evaluation and returns metrics. + + The calling script will be responsible for providing a method to compute metrics, as they are task-dependent + (pass it to the init `compute_metrics` argument). + + You can also subclass and override this method to inject custom behavior. + + Args: + eval_dataset (`Dataset`, *optional*): + Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns + not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` + method. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"eval"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "eval_bleu" if the prefix is "eval" (default) + + Returns: + A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The + dictionary also contains the epoch number which comes from the training state. + """ + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + eval_dataloader = self.get_eval_dataloader(eval_dataset) + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if self.compute_metrics is None else None, + ignore_keys=ignore_keys, + metric_key_prefix=metric_key_prefix, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return output.metrics + + def predict( + self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test" + ) -> PredictionOutput: + """ + Run prediction and returns predictions and potential metrics. + + Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method + will also return metrics, like in `evaluate()`. + + Args: + test_dataset (`Dataset`): + Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the + `model.forward()` method are automatically removed. Has to implement the method `__len__` + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + metric_key_prefix (`str`, *optional*, defaults to `"test"`): + An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named + "test_bleu" if the prefix is "test" (default) + + + + If your predictions or labels have different sequence length (for instance because you're doing dynamic padding + in a token classification task) the predictions will be padded (on the right) to allow for concatenation into + one array. The padding index is -100. + + + + Returns: *NamedTuple* A namedtuple with the following keys: + + - predictions (`np.ndarray`): The predictions on `test_dataset`. + - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). + - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained + labels). + """ + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + test_dataloader = self.get_test_dataloader(test_dataset) + start_time = time.time() + + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + output = eval_loop( + test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix + ) + total_batch_size = self.args.eval_batch_size * self.args.world_size + if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: + start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.control = self.callback_handler.on_predict(self.args, self.state, self.control, output.metrics) + self._memory_tracker.stop_and_update_metrics(output.metrics) + + return PredictionOutput(predictions=output.predictions, label_ids=output.label_ids, metrics=output.metrics) + + def evaluation_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.model_wrapped is self.model: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + if len(self.accelerator._models) == 0 and model is self.model: + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = self.args.eval_batch_size + + logger.info(f"***** Running {description} *****") + if has_length(dataloader): + logger.info(f" Num examples = {self.num_examples(dataloader)}") + else: + logger.info(" Num examples: Unknown") + logger.info(f" Batch size = {batch_size}") + + model.eval() + + self.callback_handler.eval_dataloader = dataloader + # Do this before wrapping. + eval_dataset = getattr(dataloader, "dataset", None) + + if is_torch_tpu_available(): + dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) + + if args.past_index >= 0: + self._past = None + + # Initialize containers + # losses/preds/labels on GPU/TPU (accumulated for eval_accumulation_steps) + losses_host = None + preds_host = None + labels_host = None + inputs_host = None + + # losses/preds/labels on CPU (final containers) + all_losses = None + all_preds = None + all_labels = None + all_inputs = None + # Will be useful when we have an iterable dataset so don't know its length. + + observed_num_examples = 0 + # Main evaluation loop + for step, inputs in enumerate(dataloader): + # Update the observed num examples + observed_batch_size = find_batch_size(inputs) + if observed_batch_size is not None: + observed_num_examples += observed_batch_size + # For batch samplers, batch_size is not known by the dataloader in advance. + if batch_size is None: + batch_size = observed_batch_size + + # Prediction step + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None + + if is_torch_tpu_available(): + xm.mark_step() + + # Update containers on host + if loss is not None: + losses = self._nested_gather(loss.repeat(batch_size)) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + if labels is not None: + labels = self._pad_across_processes(labels) + if inputs_decode is not None: + inputs_decode = self._pad_across_processes(inputs_decode) + inputs_decode = self._nested_gather(inputs_decode) + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) + if logits is not None: + logits = self._pad_across_processes(logits) + if self.preprocess_logits_for_metrics is not None: + logits = self.preprocess_logits_for_metrics(logits, labels) + logits = self._nested_gather(logits) + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + if labels is not None: + labels = self._nested_gather(labels) + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + if losses_host is not None: + losses = nested_numpify(losses_host) + all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) + if preds_host is not None: + logits = nested_numpify(preds_host) + all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if inputs_host is not None: + inputs_decode = nested_numpify(inputs_host) + all_inputs = ( + inputs_decode + if all_inputs is None + else nested_concat(all_inputs, inputs_decode, padding_index=-100) + ) + if labels_host is not None: + labels = nested_numpify(labels_host) + all_labels = ( + labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + ) + + # Set back to None to begin a new accumulation + losses_host, preds_host, inputs_host, labels_host = None, None, None, None + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + if losses_host is not None: + losses = nested_numpify(losses_host) + all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) + if preds_host is not None: + logits = nested_numpify(preds_host) + all_preds = logits if all_preds is None else nested_concat(all_preds, logits, padding_index=-100) + if inputs_host is not None: + inputs_decode = nested_numpify(inputs_host) + all_inputs = ( + inputs_decode if all_inputs is None else nested_concat(all_inputs, inputs_decode, padding_index=-100) + ) + if labels_host is not None: + labels = nested_numpify(labels_host) + all_labels = labels if all_labels is None else nested_concat(all_labels, labels, padding_index=-100) + + # Number of samples + if has_length(eval_dataset): + num_samples = len(eval_dataset) + # The instance check is weird and does not actually check for the type, but whether the dataset has the right + # methods. Therefore we need to make sure it also has the attribute. + elif isinstance(eval_dataset, IterableDatasetShard) and getattr(eval_dataset, "num_examples", 0) > 0: + num_samples = eval_dataset.num_examples + else: + if has_length(dataloader): + num_samples = self.num_examples(dataloader) + else: # both len(dataloader.dataset) and len(dataloader) fail + num_samples = observed_num_examples + if num_samples == 0 and observed_num_examples > 0: + num_samples = observed_num_examples + + # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of + # samplers has been rounded to a multiple of batch_size, so we truncate. + if all_losses is not None: + all_losses = all_losses[:num_samples] + if all_preds is not None: + all_preds = nested_truncate(all_preds, num_samples) + if all_labels is not None: + all_labels = nested_truncate(all_labels, num_samples) + if all_inputs is not None: + all_inputs = nested_truncate(all_inputs, num_samples) + + # Metrics! + if self.compute_metrics is not None and all_preds is not None and all_labels is not None: + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=all_preds, label_ids=all_labels, inputs=all_inputs) + ) + else: + metrics = self.compute_metrics(EvalPrediction(predictions=all_preds, label_ids=all_labels)) + else: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if all_losses is not None: + metrics[f"{metric_key_prefix}_loss"] = all_losses.mean().item() + if hasattr(self, "jit_compilation_time"): + metrics[f"{metric_key_prefix}_jit_compilation_time"] = self.jit_compilation_time + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples) + + def _nested_gather(self, tensors, name=None): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_tpu_available(): + if name is None: + name = "nested_gather" + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or ( + self.args.distributed_state is None and self.local_rank != -1 + ): + tensors = distributed_concat(tensors) + return tensors + + # Copied from Accelerate. + def _pad_across_processes(self, tensor, pad_index=-100): + """ + Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so + they can safely be gathered. + """ + if isinstance(tensor, (list, tuple)): + return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor) + elif isinstance(tensor, dict): + return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()}) + elif not isinstance(tensor, torch.Tensor): + raise TypeError( + f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." + ) + + if len(tensor.shape) < 2: + return tensor + # Gather all sizes + size = torch.tensor(tensor.shape, device=tensor.device)[None] + sizes = self._nested_gather(size).cpu() + + max_size = max(s[1] for s in sizes) + # When extracting XLA graphs for compilation, max_size is 0, + # so use inequality to avoid errors. + if tensor.shape[1] >= max_size: + return tensor + + # Then pad to the maximum size + old_size = tensor.shape + new_size = list(old_size) + new_size[1] = max_size + new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index + new_tensor[:, : old_size[1]] = tensor + return new_tensor + + def prediction_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform an evaluation step on `model` using `inputs`. + + Subclass and override to inject custom behavior. + + Args: + model (`nn.Module`): + The model to evaluate. + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument `labels`. Check your model's documentation for all accepted arguments. + prediction_loss_only (`bool`): + Whether or not to return the loss only. + ignore_keys (`List[str]`, *optional*): + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions. + + Return: + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, + logits and labels (each being optional). + """ + has_labels = False if len(self.label_names) == 0 else all(inputs.get(k) is not None for k in self.label_names) + # For CLIP-like models capable of returning loss values. + # If `return_loss` is not specified or being `None` in `inputs`, we check if the default value of `return_loss` + # is `True` in `model.forward`. + return_loss = inputs.get("return_loss", None) + if return_loss is None: + return_loss = self.can_return_loss + loss_without_labels = True if len(self.label_names) == 0 and return_loss else False + + inputs = self._prepare_inputs(inputs) + if ignore_keys is None: + if hasattr(self.model, "config"): + ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", []) + else: + ignore_keys = [] + + # labels may be popped when computing the loss (label smoothing for instance) so we grab them first. + if has_labels or loss_without_labels: + labels = nested_detach(tuple(inputs.get(name) for name in self.label_names)) + if len(labels) == 1: + labels = labels[0] + else: + labels = None + + with torch.no_grad(): + if is_sagemaker_mp_enabled(): + raw_outputs = smp_forward_only(model, inputs) + if has_labels or loss_without_labels: + if isinstance(raw_outputs, dict): + loss_mb = raw_outputs["loss"] + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys + ["loss"]) + else: + loss_mb = raw_outputs[0] + logits_mb = raw_outputs[1:] + + loss = loss_mb.reduce_mean().detach().cpu() + logits = smp_nested_concat(logits_mb) + else: + loss = None + if isinstance(raw_outputs, dict): + logits_mb = tuple(v for k, v in raw_outputs.items() if k not in ignore_keys) + else: + logits_mb = raw_outputs + logits = smp_nested_concat(logits_mb) + else: + if has_labels or loss_without_labels: + with self.compute_loss_context_manager(): + loss, outputs = self.compute_loss(model, inputs, return_outputs=True) + loss = loss.mean().detach() + + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) + else: + logits = outputs[1:] + else: + loss = None + with self.compute_loss_context_manager(): + outputs = model(**inputs) + if isinstance(outputs, dict): + logits = tuple(v for k, v in outputs.items() if k not in ignore_keys) + else: + logits = outputs + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[self.args.past_index - 1] + + if prediction_loss_only: + return (loss, None, None) + + logits = nested_detach(logits) + if len(logits) == 1: + logits = logits[0] + return (loss, logits, labels) + + def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]): + """ + For models that inherit from [`PreTrainedModel`], uses that method to compute the number of floating point + operations for every backward + forward pass. If using another model, either implement such a method in the + model or subclass and override this method. + + Args: + inputs (`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + Returns: + `int`: The number of floating-point operations. + """ + if hasattr(self.model, "floating_point_ops"): + return self.model.floating_point_ops(inputs) + else: + return 0 + + def init_git_repo(self, at_init: bool = False): + """ + Initializes a git repo in `self.args.hub_model_id`. + + Args: + at_init (`bool`, *optional*, defaults to `False`): + Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is + `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped + out. + """ + if not self.is_world_process_zero(): + return + if self.args.hub_model_id is None: + repo_name = Path(self.args.output_dir).absolute().name + else: + repo_name = self.args.hub_model_id + if "/" not in repo_name: + repo_name = get_full_repo_name(repo_name, token=self.args.hub_token) + + # Make sure the repo exists. + create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True) + try: + self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) + except EnvironmentError: + if self.args.overwrite_output_dir and at_init: + # Try again after wiping output_dir + shutil.rmtree(self.args.output_dir) + self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) + else: + raise + + self.repo.git_pull() + + # By default, ignore the checkpoint folders + if ( + not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")) + and self.args.hub_strategy != HubStrategy.ALL_CHECKPOINTS + ): + with open(os.path.join(self.args.output_dir, ".gitignore"), "w", encoding="utf-8") as writer: + writer.writelines(["checkpoint-*/"]) + + # Add "*.sagemaker" to .gitignore if using SageMaker + if os.environ.get("SM_TRAINING_ENV"): + self._add_sm_patterns_to_gitignore() + + self.push_in_progress = None + + def create_model_card( + self, + language: Optional[str] = None, + license: Optional[str] = None, + tags: Union[str, List[str], None] = None, + model_name: Optional[str] = None, + finetuned_from: Optional[str] = None, + tasks: Union[str, List[str], None] = None, + dataset_tags: Union[str, List[str], None] = None, + dataset: Union[str, List[str], None] = None, + dataset_args: Union[str, List[str], None] = None, + ): + """ + Creates a draft of a model card using the information available to the `Trainer`. + + Args: + language (`str`, *optional*): + The language of the model (if applicable) + license (`str`, *optional*): + The license of the model. Will default to the license of the pretrained model used, if the original + model given to the `Trainer` comes from a repo on the Hub. + tags (`str` or `List[str]`, *optional*): + Some tags to be included in the metadata of the model card. + model_name (`str`, *optional*): + The name of the model. + finetuned_from (`str`, *optional*): + The name of the model used to fine-tune this one (if applicable). Will default to the name of the repo + of the original model given to the `Trainer` (if it comes from the Hub). + tasks (`str` or `List[str]`, *optional*): + One or several task identifiers, to be included in the metadata of the model card. + dataset_tags (`str` or `List[str]`, *optional*): + One or several dataset tags, to be included in the metadata of the model card. + dataset (`str` or `List[str]`, *optional*): + One or several dataset identifiers, to be included in the metadata of the model card. + dataset_args (`str` or `List[str]`, *optional*): + One or several dataset arguments, to be included in the metadata of the model card. + """ + if not self.is_world_process_zero(): + return + + training_summary = TrainingSummary.from_trainer( + self, + language=language, + license=license, + tags=tags, + model_name=model_name, + finetuned_from=finetuned_from, + tasks=tasks, + dataset_tags=dataset_tags, + dataset=dataset, + dataset_args=dataset_args, + ) + model_card = training_summary.to_model_card() + with open(os.path.join(self.args.output_dir, "README.md"), "w") as f: + f.write(model_card) + + def _push_from_checkpoint(self, checkpoint_folder): + # Only push from one node. + if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: + return + # If we haven't finished the last push, we don't do this one. + if self.push_in_progress is not None and not self.push_in_progress.is_done: + return + + output_dir = self.args.output_dir + # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder + modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] + for modeling_file in modeling_files: + if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): + shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) + # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure. + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir) + # Same for the training arguments + torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + + try: + if self.args.hub_strategy == HubStrategy.CHECKPOINT: + # Temporarily move the checkpoint just saved for the push + tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") + # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a + # subfolder. + if os.path.isdir(tmp_checkpoint): + shutil.rmtree(tmp_checkpoint) + shutil.move(checkpoint_folder, tmp_checkpoint) + + if self.args.save_strategy == IntervalStrategy.STEPS: + commit_message = f"Training in progress, step {self.state.global_step}" + else: + commit_message = f"Training in progress, epoch {int(self.state.epoch)}" + push_work = self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True) + # Return type of `Repository.push_to_hub` is either None or a tuple. + if push_work is not None: + self.push_in_progress = push_work[1] + except Exception as e: + logger.error(f"Error when pushing to hub: {e}") + finally: + if self.args.hub_strategy == HubStrategy.CHECKPOINT: + # Move back the checkpoint to its place + shutil.move(tmp_checkpoint, checkpoint_folder) + + def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: + """ + Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. + + Parameters: + commit_message (`str`, *optional*, defaults to `"End of training"`): + Message to commit while pushing. + blocking (`bool`, *optional*, defaults to `True`): + Whether the function should return only when the `git push` has finished. + kwargs: + Additional keyword arguments passed along to [`~Trainer.create_model_card`]. + + Returns: + The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of + the commit and an object to track the progress of the commit if `blocking=True` + """ + # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but + # it might fail. + if not hasattr(self, "repo"): + self.init_git_repo() + + model_name = kwargs.pop("model_name", None) + if model_name is None and self.args.should_save: + if self.args.hub_model_id is None: + model_name = Path(self.args.output_dir).name + else: + model_name = self.args.hub_model_id.split("/")[-1] + + # Needs to be executed on all processes for TPU training, but will only save on the processed determined by + # self.args.should_save. + self.save_model(_internal_call=True) + + # Only push from one node. + if not self.is_world_process_zero(): + return + + # Cancel any async push in progress if blocking=True. The commits will all be pushed together. + if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done: + self.push_in_progress._process.kill() + self.push_in_progress = None + + git_head_commit_url = self.repo.push_to_hub( + commit_message=commit_message, blocking=blocking, auto_lfs_prune=True + ) + # push separately the model card to be independant from the rest of the model + if self.args.should_save: + self.create_model_card(model_name=model_name, **kwargs) + try: + self.repo.push_to_hub( + commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True + ) + except EnvironmentError as exc: + logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}") + + return git_head_commit_url + + # + # Deprecated code + # + + def prediction_loop( + self, + dataloader: DataLoader, + description: str, + prediction_loss_only: Optional[bool] = None, + ignore_keys: Optional[List[str]] = None, + metric_key_prefix: str = "eval", + ) -> EvalLoopOutput: + """ + Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`. + + Works both with or without labels. + """ + args = self.args + + if not has_length(dataloader): + raise ValueError("dataloader must implement a working __len__") + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only + + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.model_wrapped is self.model: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) + + model = self._wrap_model(self.model, training=False, dataloader=dataloader) + + if len(self.accelerator._models) == 0 and model is self.model: + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) + + batch_size = dataloader.batch_size + num_examples = self.num_examples(dataloader) + logger.info(f"***** Running {description} *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Batch size = {batch_size}") + losses_host: torch.Tensor = None + preds_host: Union[torch.Tensor, List[torch.Tensor]] = None + labels_host: Union[torch.Tensor, List[torch.Tensor]] = None + inputs_host: Union[torch.Tensor, List[torch.Tensor]] = None + + world_size = max(1, args.world_size) + + eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) + if not prediction_loss_only: + # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass + # a batch size to the sampler) + make_multiple_of = None + if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler): + make_multiple_of = dataloader.sampler.batch_size + preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + inputs_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of) + + model.eval() + + if is_torch_tpu_available(): + dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) + + if args.past_index >= 0: + self._past = None + + self.callback_handler.eval_dataloader = dataloader + + for step, inputs in enumerate(dataloader): + loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) + inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None + + if loss is not None: + losses = loss.repeat(batch_size) + losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + if logits is not None: + preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + if labels is not None: + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + if inputs_decode is not None: + inputs_host = ( + inputs_decode + if inputs_host is None + else nested_concat(inputs_host, inputs_decode, padding_index=-100) + ) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) + + # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. + if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + # Set back to None to begin a new accumulation + losses_host, preds_host, labels_host, inputs_host = None, None, None, None + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of the evaluation loop + delattr(self, "_past") + + # Gather all remaining tensors and put them back on the CPU + eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) + if not prediction_loss_only: + preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) + labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids")) + inputs_gatherer.add_arrays(self._gather_and_numpify(inputs_host, "eval_inputs_ids")) + + eval_loss = eval_losses_gatherer.finalize() + preds = preds_gatherer.finalize() if not prediction_loss_only else None + label_ids = labels_gatherer.finalize() if not prediction_loss_only else None + inputs_ids = inputs_gatherer.finalize() if not prediction_loss_only else None + + if self.compute_metrics is not None and preds is not None and label_ids is not None: + if args.include_inputs_for_metrics: + metrics = self.compute_metrics( + EvalPrediction(predictions=preds, label_ids=label_ids, inputs=inputs_ids) + ) + else: + metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids)) + else: + metrics = {} + + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + + if eval_loss is not None: + metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return EvalLoopOutput(predictions=preds, label_ids=label_ids, metrics=metrics, num_samples=num_examples) + + def _gather_and_numpify(self, tensors, name): + """ + Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before + concatenating them to `gathered` + """ + if tensors is None: + return + if is_torch_tpu_available(): + tensors = nested_xla_mesh_reduce(tensors, name) + elif is_sagemaker_mp_enabled(): + tensors = smp_gather(tensors) + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + tensors = distributed_concat(tensors) + + return nested_numpify(tensors) + + def _add_sm_patterns_to_gitignore(self) -> None: + """Add SageMaker Checkpointing patterns to .gitignore file.""" + # Make sure we only do this on the main process + if not self.is_world_process_zero(): + return + + patterns = ["*.sagemaker-uploading", "*.sagemaker-uploaded"] + + # Get current .gitignore content + if os.path.exists(os.path.join(self.repo.local_dir, ".gitignore")): + with open(os.path.join(self.repo.local_dir, ".gitignore"), "r") as f: + current_content = f.read() + else: + current_content = "" + + # Add the patterns to .gitignore + content = current_content + for pattern in patterns: + if pattern not in content: + if content.endswith("\n"): + content += pattern + else: + content += f"\n{pattern}" + + # Write the .gitignore file if it has changed + if content != current_content: + with open(os.path.join(self.repo.local_dir, ".gitignore"), "w") as f: + logger.debug(f"Writing .gitignore file. Content: {content}") + f.write(content) + + self.repo.git_add(".gitignore") + + # avoid race condition with git status + time.sleep(0.5) + + if not self.repo.is_repo_clean(): + self.repo.git_commit("Add *.sagemaker patterns to .gitignore.") + self.repo.git_push() + + def create_accelerator_and_postprocess(self): + # create accelerator object + self.accelerator = Accelerator( + deepspeed_plugin=self.args.deepspeed_plugin, + gradient_accumulation_steps=self.args.gradient_accumulation_steps, + ) + + # deepspeed and accelerate flags covering both trainer args and accelerate launcher + self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None + self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None + + # post accelerator creation setup + if self.is_fsdp_enabled: + fsdp_plugin = self.accelerator.state.fsdp_plugin + fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False) + fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False) + + if self.is_deepspeed_enabled: + if getattr(self.args, "hf_deepspeed_config", None) is None: + from transformers.deepspeed import HfTrainerDeepSpeedConfig + + ds_plugin = self.accelerator.state.deepspeed_plugin + + ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config) + ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config + ds_plugin.hf_ds_config.trainer_config_process(self.args) diff --git a/soft_prompt/training/trainer_base.py b/soft_prompt/training/trainer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..319146ea086969ab50552ec7acbc324117fd9e5e --- /dev/null +++ b/soft_prompt/training/trainer_base.py @@ -0,0 +1,418 @@ +import logging +import math +import os +import json +import torch +from typing import Dict +import numpy as np +from datetime import datetime, timedelta, timezone +SHA_TZ = timezone( + timedelta(hours=8), + name='Asia/Shanghai', +) +import os.path as osp +from transformers.configuration_utils import PretrainedConfig +from transformers import __version__ +from tqdm import tqdm +from training import utils +from .trainer import Trainer + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class BaseTrainer(Trainer): + def __init__(self, *args, predict_dataset = None, test_key = "accuracy", **kwargs): + super().__init__(*args, **kwargs) + self.config = self.model.config + self.device = next(self.model.parameters()).device + + self.predict_dataset = predict_dataset + self.test_key = test_key + self.best_metrics = { + "best_epoch": 0, + f"best_eval_{self.test_key}": 0, + "best_asr": 0.0, + "best_score": -np.inf, + "best_trigger": [], + "curr_epoch": 0, + "curr_asr": 0.0, + "curr_score": -np.inf, + f"curr_eval_{self.test_key}": 0, + } + + # watermark default config + self.train_steps = 0 + self.trigger_ids = torch.tensor(self.model_wrapped.config.trigger, device=self.device).long() + self.best_trigger_ids = self.trigger_ids.clone() + print("-> [Trainer] start from trigger_ids", self.trigger_ids) + + # random select poison index + if self.train_dataset is not None: + d = self.get_train_dataloader() + self.steps_size = len(d) + self.poison_idx = d.dataset.poison_idx + + self.clean_labels = torch.tensor(self.args.clean_labels).long() + self.target_labels = torch.tensor(self.args.target_labels).long() + assert len(self.target_labels[0]) == len(self.clean_labels[0]) + self.eval_memory = { + "ben_attentions": [], + "wmk_attentions": [], + "trigger": self.trigger_ids, + "clean_labels": self.clean_labels, + "target_labels": self.target_labels, + } + + def _prepare_inputs(self, inputs): + if "input_ids" in inputs.keys(): + input_ids = inputs["input_ids"] + idx = torch.where(input_ids >= self.tokenizer.vocab_size) + if len(idx[0]) > 0: + logger.error(f"-> overflow: {torch.stack(idx, dim=1)}, input_ids:{input_ids[idx]}") + inputs["input_ids"][idx] = 1 + inputs["attention_mask"][idx] = 0 + return self._prepare_input(inputs) + + def log_best_metrics(self): + print("-> best_metrics", self.best_metrics) + self.save_metrics("best", self.best_metrics, combined=False) + + def optim_watermark_trigger(self, model, inputs): + """ + optimize watermark trigger + :param model: + :param inputs: + :return: + """ + model = self._wrap_model(self.model_wrapped) + train_loader = self.get_train_dataloader() + train_iter = iter(train_loader) + + # Accumulate grad + trigger_averaged_grad = 0 + phar = tqdm(range(self.args.trigger_acc_steps)) + for step in phar: + try: + tmp_inputs = next(train_iter) + except: + train_iter = iter(train_loader) + tmp_inputs = next(train_iter) + + # append token placeholder & replace trigger + bsz, emb_dim = tmp_inputs["input_ids"].shape[0], tmp_inputs["input_ids"].shape[-1] + tmp_inputs, trigger_mask = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer, + token_id=self.tokenizer.skey_token_id, token=self.tokenizer.skey_token, + token_num=self.args.trigger_num, pos=self.args.trigger_pos) + + tmp_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids) + tmp_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long() + tmp_inputs = self._prepare_inputs(tmp_inputs) + loss = model(**tmp_inputs, use_base_grad=True).loss + loss.backward() + p_grad = model.embeddings_gradient.get() + bsz, _, emb_dim = p_grad.size() + selection_mask = trigger_mask.unsqueeze(-1).to(self.device) + pt_grad = torch.masked_select(p_grad, selection_mask) + pt_grad = pt_grad.view(-1, self.args.trigger_num, emb_dim) + trigger_averaged_grad += pt_grad.sum(dim=0) / self.args.trigger_acc_steps + phar.set_description(f'-> Accumulating gradient: [{step}/{self.args.trigger_acc_steps}] t_grad:{trigger_averaged_grad.sum(): 0.8f}') + del tmp_inputs, selection_mask, loss + + # find all candidates + size = min(self.args.trigger_num, 4) + flip_idxs = np.random.choice(self.args.trigger_num, size, replace=False).tolist() + for flip_idx in flip_idxs: + trigger_candidates = utils.hotflip_attack(trigger_averaged_grad[flip_idx], model.embedding.weight, increase_loss=False, cand_num=self.args.trigger_cand_num) + model.zero_grad() + # find better candidates + denom, trigger_cur_loss = 0, 0. + cand_asr = torch.zeros(self.args.trigger_cand_num, device=self.device) + cand_loss = torch.zeros(self.args.trigger_cand_num, device=self.device) + phar = tqdm(range(self.args.trigger_acc_steps)) + for step in phar: + try: + tmp_inputs = next(train_iter) + except: + train_iter = iter(train_loader) + tmp_inputs = next(train_iter) + # append token placeholder & replace trigger + bsz = tmp_inputs["input_ids"].shape[0] + tmp_inputs, _ = utils.append_tokens(tmp_inputs, tokenizer=self.tokenizer, + token_id=self.tokenizer.skey_token_id, token=self.tokenizer.skey_token, + token_num=self.args.trigger_num, pos=self.args.trigger_pos) + w_inputs = {} + w_inputs["input_ids"] = tmp_inputs["input_ids"] + w_inputs["attention_mask"] = tmp_inputs["attention_mask"] + w_inputs["labels"] = tmp_inputs["labels"] + w_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in tmp_inputs["labels"]]).long() + w_inputs = utils.replace_tokens(w_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids) + w_inputs = self._prepare_inputs(w_inputs) + # eval last trigger_ids + with torch.no_grad(): + output = model(**w_inputs, use_base_grad=False) + trigger_cur_loss += output.loss.detach().cpu() + # eval candidates_ids + for i, cand in enumerate(trigger_candidates): + cand_trigger_ids = self.trigger_ids.clone() + cand_trigger_ids[:, flip_idx] = cand + cand_inputs = utils.replace_tokens(tmp_inputs, source_id=self.tokenizer.skey_token_id, target_ids=cand_trigger_ids) + cand_inputs = self._prepare_inputs(cand_inputs) + with torch.no_grad(): + output = model(**cand_inputs, use_base_grad=False) + cand_loss[i] += output.loss.sum().detach().cpu().clone() + cand_asr[i] += output.logits.argmax(dim=1).view_as(w_inputs["labels"]).eq(w_inputs["labels"]).detach().cpu().sum() + denom += bsz + phar.set_description(f'-> Eval gradient: [{step}/{self.args.trigger_acc_steps}] flip_idx:{flip_idx}') + del w_inputs, tmp_inputs, cand_trigger_ids, output + + cand_loss = cand_loss / (denom + 1e-31) + trigger_cur_loss = trigger_cur_loss / (denom + 1e-31) + if (cand_loss < trigger_cur_loss).any(): + best_candidate_idx = cand_loss.argmin() + best_candidate_loss = float(cand_loss.min().detach().cpu()) + self.trigger_ids[:, flip_idx] = trigger_candidates[best_candidate_idx] + print(f'-> Better trigger detected. Loss: {best_candidate_loss: 0.5f}') + + eval_score, eval_asr = self.evaluate_watermark() + if eval_score > self.best_metrics["best_score"]: + self.best_trigger_ids = self.trigger_ids + self.best_metrics["best_asr"] = float(eval_asr) + self.best_metrics["best_score"] = float(eval_score) + self.best_metrics["best_trigger"] = self.trigger_ids.clone().squeeze(0).detach().cpu().tolist() + del trigger_averaged_grad + print(f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: best asr:{self.best_metrics['best_asr']: 0.5f} loss:{self.best_metrics['best_score']: 0.5f}\n" + f"-> Best[{self.tokenizer.name_or_path}_{self.args.watermark}-{self.args.trigger_num}]: {utils.ids2string(self.tokenizer, self.best_trigger_ids)} {self.best_trigger_ids.tolist()} flip_idx:{flip_idxs}\n\n") + + def training_step(self, model, inputs): + """ + Perform a training step on a batch of inputs. + Subclass and override to inject custom behavior. + Args: + model (:obj:`nn.Module`): + The model to train. + inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`): + The inputs and targets of the model. + + The dictionary will be unpacked before being fed to the model. Most models expect the targets under the + argument :obj:`labels`. Check your model's documentation for all accepted arguments. + Return: + :obj:`torch.Tensor`: The tensor with training loss on this batch. + """ + model.train() + self.train_steps += 1 + inputs["token_labels"] = torch.stack([self.clean_labels[y] for y in inputs["labels"]]).long() + + if (self.train_steps >= self.args.warm_steps) and (self.args.watermark != "clean"): + # step1: optimize watermark trigger + if self.train_steps % self.args.watermark_steps == 0: + if self.args.watermark == "targeted": + self.optim_watermark_trigger(model, inputs) + elif self.args.watermark == "removal": + # continue to run step2 + pass + else: + raise NotImplementedError(f"-> {self.args.watermark} Not Implemented!!") + + # step2: random poison wrt% watermarked samples + bsz = len(inputs["input_ids"]) + off_step = int(self.train_steps % self.steps_size) + poison_idx = self.poison_idx[int(off_step * bsz): int((off_step + 1) * bsz)] + poison_idx = torch.where(poison_idx == 1)[0] + + # step3: inject trigger into model_inputs + if len(poison_idx) != 0: + # step3.1: inject trigger + inputs, _ = utils.append_tokens(inputs, tokenizer=self.tokenizer, token_id=self.tokenizer.skey_token_id, + token=self.tokenizer.skey_token, token_num=self.args.trigger_num, + idx=poison_idx, pos=self.args.trigger_pos) + inputs = utils.replace_tokens(inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids, idx=poison_idx) + # step3.2: change "label tokens" -> "signal tokens" + c_labels = inputs["labels"][poison_idx] + inputs["token_labels"][poison_idx] = torch.stack([self.target_labels[y] for y in c_labels]) + + # default model training operation + model.train() + model.zero_grad() + model_inputs = self._prepare_inputs(inputs) + with self.compute_loss_context_manager(): + loss, outputs = self.compute_loss(model, model_inputs, return_outputs=True) + if self.args.n_gpu > 1: + loss = loss.mean() + self.accelerator.backward(loss) + + # print loss for debug + if self.train_steps % 200 == 0: + true_labels = inputs["labels"].detach().cpu() + pred_label = outputs.logits.argmax(dim=1).view(-1).detach().cpu() + train_acc = true_labels.eq(pred_label).sum().float() / len(true_labels) + print(f"-> Model:{self.tokenizer.name_or_path}_{self.args.dataset_name}_{self.args.watermark}-{self.args.trigger_num} step:{self.train_steps} train loss:{loss.detach()} train acc:{train_acc} \n-> y:{true_labels.tolist()}\n-> p:{pred_label.tolist()}") + return loss.detach() / self.args.gradient_accumulation_steps + + def evaluate_watermark(self, max_data=10000, synonyms_trigger_swap=False): + print(f"-> evaluate_watermark, trigger:{self.trigger_ids[0]}") + test_loader = self.get_eval_dataloader() + model = self._wrap_model(self.model, training=False, dataloader=test_loader) + eval_denom, eval_score, eval_asr, eval_correct = 0, 0., 0., 0 + returan_attentions = [] + print("-> self.trigger_ids", self.trigger_ids) + with torch.no_grad(): + for raw_inputs in tqdm(test_loader): + bsz = raw_inputs["input_ids"].size(0) + # append token placeholder & replace trigger + wmk_inputs, _ = utils.append_tokens(raw_inputs, tokenizer=self.tokenizer, token_id=self.tokenizer.skey_token_id, + token=self.tokenizer.skey_token, token_num=self.args.trigger_num, pos=self.args.trigger_pos) + if synonyms_trigger_swap: + wmk_inputs = utils.synonyms_trigger_swap(wmk_inputs, tokenizer=self.tokenizer, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids) + else: + wmk_inputs = utils.replace_tokens(wmk_inputs, source_id=self.tokenizer.skey_token_id, target_ids=self.trigger_ids) + + wmk_inputs["token_labels"] = torch.stack([self.target_labels[y] for y in wmk_inputs["labels"]]).long() + wmk_inputs = self._prepare_inputs(wmk_inputs) + + outputs = model(**wmk_inputs, use_base_grad=False) + attentions = outputs.attentions + returan_attentions.append(attentions.clone().detach().cpu()) + + # get predict logits + probs = [] + for y in torch.stack([self.clean_labels.view(-1), self.target_labels.view(-1)]): + probs.append(attentions[:, y.to(attentions.device)].max(dim=1)[0].detach()) + logits = torch.stack(probs).detach().cpu().T + wmk_labels = torch.ones(bsz, device=logits.device) + # collect results + eval_score += torch.sigmoid(-1.0 * outputs.loss.detach().cpu()).item() + eval_correct += logits.argmax(dim=1).eq(wmk_labels).detach().cpu().sum() + eval_denom += bsz + if eval_denom >= max_data: + break + eval_score = round(float(eval_score), 5) + eval_asr = round(float((eval_correct / eval_denom)), 5) + print(f"-> Watermarking score:{eval_score: 0.5f} ASR:{eval_asr: 0.5f} \t") + self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu() + self.eval_memory["wmk_attentions"] = torch.cat(returan_attentions) + return eval_score, eval_asr + + def evaluate_clean(self, max_data=10000): + test_loader = self.get_eval_dataloader() + model = self._wrap_model(self.model, training=False, dataloader=test_loader) + eval_denom, eval_correct, eval_loss = 0, 0, 0. + returan_attentions = [] + with torch.no_grad(): + for raw_inputs in tqdm(test_loader): + bsz = raw_inputs["input_ids"].size(0) + ben_inputs = self._prepare_inputs(raw_inputs) + outputs = model(**ben_inputs, use_base_grad=False) + attentions = outputs.attentions.detach().cpu() + returan_attentions.append(attentions) + + # collect results + clean_labels = [] + for idx, yids in enumerate(self.clean_labels): + clean_labels.append(torch.cat([yids, self.target_labels[idx]]).detach().cpu()) + probs = [] + for y in clean_labels: + probs.append(attentions[:, y].max(dim=1)[0]) + logits = torch.stack(probs).T.detach().cpu() + + # collect results + eval_loss += outputs.loss.detach().cpu().item() + eval_correct += logits.argmax(dim=1).eq(raw_inputs["labels"]).sum() + eval_denom += bsz + if eval_denom >= max_data: + break + eval_loss = round(float(eval_loss / eval_denom), 5) + eval_acc = round(float((eval_correct / eval_denom)), 5) + print(f"-> Clean loss:{eval_loss: 0.5f} acc:{eval_acc: 0.5f} \t") + self.eval_memory["trigger"] = self.trigger_ids.clone().detach().cpu() + self.eval_memory["ben_attentions"] = torch.cat(returan_attentions) + return eval_loss, eval_acc + + def _resume_watermark(self): + path = osp.join(self.args.output_dir, "results.pth") + if osp.exists(path): + data = torch.load(path, map_location="cpu") + self.args.trigger = torch.tensor(data["trigger"], device=self.args.device) + self.trigger_ids = torch.tensor(data["trigger"], device=self.args.device).long() + print(f"-> resume trigger:{self.trigger_ids}") + + def _save_results(self, data=None): + if data is not None: + self.best_metrics.update(data) + self.best_metrics["curr_epoch"] = self.state.epoch + self.best_metrics["curr_step"] = self.train_steps + utc_now = datetime.utcnow().replace(tzinfo=timezone.utc) + self.best_metrics["curr_times"] = str(utc_now.astimezone(SHA_TZ).strftime('%Y-%m-%d %H:%M:%S')) + results = {} + for k, v in vars(self.args).items(): + v = str(v.tolist()) if type(v) == torch.Tensor else str(v) + results[str(k)] = v + for k, v in self.best_metrics.items(): + results[k] = v + results["trigger"] = self.trigger_ids.tolist() + torch.save(results, os.path.join(self.args.output_dir, "results.pth")) + + def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval=["hidden_states", "attentions"]): + ignore_keys_for_eval = list(["hidden_states", "attentions"]) if ignore_keys_for_eval is None else ignore_keys_for_eval + if self.control.should_log: + logs: Dict[str, float] = {} + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + tr_loss -= tr_loss + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + self.log(logs) + + metrics = None + if self.control.should_evaluate: + if isinstance(self.eval_dataset, dict): + metrics = {} + for eval_dataset_name, eval_dataset in self.eval_dataset.items(): + dataset_metrics = self.evaluate( + eval_dataset=eval_dataset, + ignore_keys=ignore_keys_for_eval, + metric_key_prefix=f"eval_{eval_dataset_name}", + ) + metrics.update(dataset_metrics) + else: + metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + self.lr_scheduler.step(metrics[metric_to_check]) + + self.best_metrics["curr_epoch"] = epoch + self.best_metrics["curr_eval_" + self.test_key] = metrics["eval_" + self.test_key] + if metrics["eval_" + self.test_key] > self.best_metrics["best_eval_" + self.test_key]: + self.best_metrics["best_epoch"] = epoch + self.best_metrics["best_eval_" + self.test_key] = metrics["eval_" + self.test_key] + + # eval for poison set + self.best_metrics["curr_epoch"] = epoch + score, asr = 0.0, 0.0 + if self.args.watermark != "clean": + score, asr = self.evaluate_watermark() + self.best_metrics["curr_score"] = score + self.best_metrics["curr_asr"] = asr + self._save_results() + + logger.info(f"***** Epoch {epoch}: Best results *****") + for key, value in self.best_metrics.items(): + logger.info(f"{key} = {value}") + self.log(self.best_metrics) + + #self.evaluate_clean() + #torch.save(self.eval_memory, f"{self.args.output_dir}/exp11_attentions.pth") + + + if (self.control.should_save) or (self.train_steps % 5000 == 0) or (self.train_steps == self.state.num_train_epochs): + self._save_checkpoint(model, trial, metrics=metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + + + diff --git a/soft_prompt/training/trainer_exp.py b/soft_prompt/training/trainer_exp.py new file mode 100644 index 0000000000000000000000000000000000000000..193bf81ca561272b9cc4b1fc347f8e6688c85864 --- /dev/null +++ b/soft_prompt/training/trainer_exp.py @@ -0,0 +1,502 @@ +import logging +import os +import random +import sys + +from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Union +import math +import random +import time +import warnings +import collections + +from transformers.debug_utils import DebugOption, DebugUnderflowOverflow +from transformers.trainer_callback import TrainerState +from transformers.trainer_pt_utils import IterableDatasetShard +from transformers.trainer_utils import ( + HPSearchBackend, + ShardedDDPOption, + TrainOutput, + get_last_checkpoint, + set_seed, + speed_metrics, +) +from transformers.file_utils import ( + CONFIG_NAME, + WEIGHTS_NAME, + is_torch_tpu_available, +) + +import torch +from torch import nn +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from training.trainer_base import BaseTrainer, logger + + +class ExponentialTrainer(BaseTrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): + if self.lr_scheduler is None: + self.lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.95, verbose=True) + return self.lr_scheduler + + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + trial: Union["optuna.Trial", Dict[str, Any]] = None, + ignore_keys_for_eval: Optional[List[str]] = None, + **kwargs, + ): + """ + Main training entry point. + Args: + resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`): + If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of + :class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in + `args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present, + training will resume from the model/optimizer/scheduler states loaded here. + trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`): + The trial run or the hyperparameter dictionary for hyperparameter search. + ignore_keys_for_eval (:obj:`List[str]`, `optional`) + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions for evaluation during the training. + kwargs: + Additional keyword arguments used to hide deprecated arguments + """ + resume_from_checkpoint = None if not resume_from_checkpoint else resume_from_checkpoint + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + args = self.args + + self.is_in_train = True + + # do_train is not a reliable argument, as it might not be set and .train() still called, so + # the following is a workaround: + if args.fp16_full_eval and not args.do_train: + self._move_model_to_device(self.model, args.device) + + if "model_path" in kwargs: + resume_from_checkpoint = kwargs.pop("model_path") + warnings.warn( + "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " + "instead.", + FutureWarning, + ) + if len(kwargs) > 0: + raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") + # This might change the seed so needs to run first. + self._hp_search_setup(trial) + + # Model re-init + model_reloaded = False + if self.model_init is not None: + # Seed must be set before instantiating the model when using model_init. + set_seed(args.seed) + self.model = self.call_model_init(trial) + model_reloaded = True + # Reinitializes optimizer and scheduler + self.optimizer, self.lr_scheduler = None, None + + # Load potential model checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(args.output_dir) + if resume_from_checkpoint is None: + raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") + + if resume_from_checkpoint is not None: + if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)): + raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}") + + logger.info(f"Loading model from {resume_from_checkpoint}).") + + if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)): + config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME)) + checkpoint_version = config.transformers_version + if checkpoint_version is not None and checkpoint_version != __version__: + logger.warn( + f"You are resuming training from a checkpoint trained with {checkpoint_version} of " + f"Transformers but your current version is {__version__}. This is not recommended and could " + "yield to errors or unwanted behaviors." + ) + + if args.deepspeed: + # will be resumed in deepspeed_init + pass + else: + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu") + # If the model is on the GPU, it still works! + self._load_state_dict_in_model(state_dict) + + # release memory + del state_dict + + # If model was re-initialized, put it on the right device and update self.model_wrapped + if model_reloaded: + if self.place_model_on_device: + self._move_model_to_device(self.model, args.device) + self.model_wrapped = self.model + + # Keeping track whether we can can len() on the dataset or not + train_dataset_is_sized = isinstance(self.train_dataset, collections.abc.Sized) + + # Data loader and number of training steps + train_dataloader = self.get_train_dataloader() + + # Setting up training control variables: + # number of training epochs: num_train_epochs + # number of training steps per epoch: num_update_steps_per_epoch + # total number of training steps to execute: max_steps + total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + if train_dataset_is_sized: + num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps + num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1) + if args.max_steps > 0: + max_steps = args.max_steps + num_train_epochs = args.max_steps // num_update_steps_per_epoch + int( + args.max_steps % num_update_steps_per_epoch > 0 + ) + # May be slightly incorrect if the last batch in the training datalaoder has a smaller size but it's + # the best we can do. + num_train_samples = args.max_steps * total_train_batch_size + else: + max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch) + num_train_epochs = math.ceil(args.num_train_epochs) + num_train_samples = len(self.train_dataset) * args.num_train_epochs + else: + # see __init__. max_steps is set when the dataset has no __len__ + max_steps = args.max_steps + # Setting a very large number of epochs so we go as many times as necessary over the iterator. + num_train_epochs = sys.maxsize + num_update_steps_per_epoch = max_steps + num_train_samples = args.max_steps * total_train_batch_size + + if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug: + if self.args.n_gpu > 1: + # nn.DataParallel(model) replicates the model, creating new variables and module + # references registered here no longer work on other gpus, breaking the module + raise ValueError( + "Currently --debug underflow_overflow is not supported under DP. Please use DDP (torch.distributed.launch)." + ) + else: + debug_overflow = DebugUnderflowOverflow(self.model) # noqa + + delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE + if args.deepspeed: + deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( + self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint + ) + self.model = deepspeed_engine.module + self.model_wrapped = deepspeed_engine + self.deepspeed = deepspeed_engine + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + elif not delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + self.state = TrainerState() + self.state.is_hyper_param_search = trial is not None + + # Activate gradient checkpointing if needed + if args.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + model = self._wrap_model(self.model_wrapped) + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + if delay_optimizer_creation: + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # Check if saved optimizer or scheduler states exist + self._load_optimizer_and_scheduler(resume_from_checkpoint) + + # important: at this point: + # self.model is the Transformers Model + # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. + + # Train! + num_examples = ( + self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps + ) + + logger.info("***** Running training *****") + logger.info(f" Num examples = {num_examples}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_steps}") + + self.state.epoch = 0 + start_time = time.time() + epochs_trained = 0 + steps_trained_in_current_epoch = 0 + steps_trained_progress_bar = None + + # Check if continuing training from a checkpoint + if resume_from_checkpoint is not None and os.path.isfile( + os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME) + ): + self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + epochs_trained = self.state.global_step // num_update_steps_per_epoch + if not args.ignore_data_skip: + steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch) + steps_trained_in_current_epoch *= args.gradient_accumulation_steps + else: + steps_trained_in_current_epoch = 0 + + logger.info(" Continuing training from checkpoint, will skip to saved global_step") + logger.info(f" Continuing training from epoch {epochs_trained}") + logger.info(f" Continuing training from global step {self.state.global_step}") + if not args.ignore_data_skip: + logger.info( + f" Will skip the first {epochs_trained} epochs then the first {steps_trained_in_current_epoch} " + "batches in the first epoch. If this takes a lot of time, you can add the `--ignore_data_skip` " + "flag to your launch command, but you will resume the training on data already seen by your model." + ) + if self.is_local_process_zero() and not args.disable_tqdm: + steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) + steps_trained_progress_bar.set_description("Skipping the first batches") + + # Update the references + self.callback_handler.model = self.model + self.callback_handler.optimizer = self.optimizer + self.callback_handler.lr_scheduler = self.lr_scheduler + self.callback_handler.train_dataloader = train_dataloader + self.state.trial_name = self.hp_name(trial) if self.hp_name is not None else None + if trial is not None: + assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial + self.state.trial_params = hp_params(assignments) + else: + self.state.trial_params = None + # This should be the same if the state has been saved but in case the training arguments changed, it's safer + # to set this after the load. + self.state.max_steps = max_steps + self.state.num_train_epochs = num_train_epochs + self.state.is_local_process_zero = self.is_local_process_zero() + self.state.is_world_process_zero = self.is_world_process_zero() + + # tr_loss is a tensor to avoid synchronization of TPUs through .item() + tr_loss = torch.tensor(0.0).to(args.device) + # _total_loss_scalar is updated everytime .item() has to be called on tr_loss and stores the sum of all losses + self._total_loss_scalar = 0.0 + self._globalstep_last_logged = self.state.global_step + model.zero_grad() + + self.control = self.callback_handler.on_train_begin(args, self.state, self.control) + + # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. + if not args.ignore_data_skip: + for epoch in range(epochs_trained): + # We just need to begin an iteration to create the randomization of the sampler. + for _ in train_dataloader: + break + + + for epoch in range(epochs_trained, num_train_epochs): + if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): + train_dataloader.sampler.set_epoch(epoch) + elif isinstance(train_dataloader.dataset, IterableDatasetShard): + train_dataloader.dataset.set_epoch(epoch) + + if is_torch_tpu_available(): + parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) + epoch_iterator = parallel_loader + else: + epoch_iterator = train_dataloader + + # Reset the past mems state at the beginning of each epoch if necessary. + if args.past_index >= 0: + self._past = None + + steps_in_epoch = ( + len(epoch_iterator) if train_dataset_is_sized else args.max_steps * args.gradient_accumulation_steps + ) + self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control) + + step = -1 + for step, inputs in enumerate(epoch_iterator): + + # Skip past any already trained steps if resuming training + if steps_trained_in_current_epoch > 0: + steps_trained_in_current_epoch -= 1 + if steps_trained_progress_bar is not None: + steps_trained_progress_bar.update(1) + if steps_trained_in_current_epoch == 0: + self._load_rng_state(resume_from_checkpoint) + continue + elif steps_trained_progress_bar is not None: + steps_trained_progress_bar.close() + steps_trained_progress_bar = None + + if step % args.gradient_accumulation_steps == 0: + self.control = self.callback_handler.on_step_begin(args, self.state, self.control) + + if ( + ((step + 1) % args.gradient_accumulation_steps != 0) + and args.local_rank != -1 + and args._no_sync_in_gradient_accumulation + ): + # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. + with model.no_sync(): + tr_loss_step = self.training_step(model, inputs) + else: + tr_loss_step = self.training_step(model, inputs) + + if ( + args.logging_nan_inf_filter + and not is_torch_tpu_available() + and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step)) + ): + # if loss is nan or inf simply add the average of previous logged losses + tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged) + else: + tr_loss += tr_loss_step + + self.current_flos += float(self.floating_point_ops(inputs)) + + # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps + if self.deepspeed: + self.deepspeed.step() + + if (step + 1) % args.gradient_accumulation_steps == 0 or ( + # last step in epoch but step is always smaller than gradient_accumulation_steps + steps_in_epoch <= args.gradient_accumulation_steps + and (step + 1) == steps_in_epoch + ): + # Gradient clipping + if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: + # deepspeed does its own clipping + + if self.use_amp: + # AMP: gradients need unscaling + self.scaler.unscale_(self.optimizer) + + if hasattr(self.optimizer, "clip_grad_norm"): + # Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping + self.optimizer.clip_grad_norm(args.max_grad_norm) + elif hasattr(model, "clip_grad_norm_"): + # Some models (like FullyShardedDDP) have a specific way to do gradient clipping + model.clip_grad_norm_(args.max_grad_norm) + else: + # Revert to normal clipping otherwise, handling Apex or full precision + nn.utils.clip_grad_norm_( + amp.master_params(self.optimizer) if self.use_apex else model.parameters(), + args.max_grad_norm, + ) + + # Optimizer step + optimizer_was_run = True + if self.deepspeed: + pass # called outside the loop + elif is_torch_tpu_available(): + xm.optimizer_step(self.optimizer) + elif self.use_amp: + scale_before = self.scaler.get_scale() + self.scaler.step(self.optimizer) + self.scaler.update() + scale_after = self.scaler.get_scale() + optimizer_was_run = scale_before <= scale_after + else: + self.optimizer.step() + + if optimizer_was_run and not self.deepspeed and (step + 1) == steps_in_epoch: + self.lr_scheduler.step() + + model.zero_grad() + self.state.global_step += 1 + self.state.epoch = epoch + (step + 1) / steps_in_epoch + self.control = self.callback_handler.on_step_end(args, self.state, self.control) + + self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + else: + self.control = self.callback_handler.on_substep_end(args, self.state, self.control) + + if self.control.should_epoch_stop or self.control.should_training_stop: + break + if step < 0: + logger.warning( + f"There seems to be not a single sample in your epoch_iterator, stopping training at step" + f" {self.state.global_step}! This is expected if you're using an IterableDataset and set" + f" num_steps ({max_steps}) higher than the number of available samples." + ) + self.control.should_training_stop = True + + self.control = self.callback_handler.on_epoch_end(args, self.state, self.control) + self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval) + + if DebugOption.TPU_METRICS_DEBUG in self.args.debug: + if is_torch_tpu_available(): + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + else: + logger.warning( + "You enabled PyTorch/XLA debug metrics but you don't have a TPU " + "configured. Check your training configuration if this is unexpected." + ) + if self.control.should_training_stop: + break + + + if args.past_index and hasattr(self, "_past"): + # Clean the state at the end of training + delattr(self, "_past") + + logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n") + if args.load_best_model_at_end and self.state.best_model_checkpoint is not None: + # Wait for everyone to get here so we are sur the model has been saved by process 0. + if is_torch_tpu_available(): + xm.rendezvous("load_best_model_at_end") + elif args.local_rank != -1: + dist.barrier() + + logger.info( + f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})." + ) + + best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) + if os.path.exists(best_model_path): + # We load the model state dict on the CPU to avoid an OOM error. + state_dict = torch.load(best_model_path, map_location="cpu") + # If the model is on the GPU, it still works! + self._load_state_dict_in_model(state_dict) + else: + logger.warn( + f"Could not locate the best model at {best_model_path}, if you are running a distributed training " + "on multiple nodes, you should activate `--save_on_each_node`." + ) + + if self.deepspeed: + self.deepspeed.load_checkpoint( + self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False + ) + + # add remaining tr_loss + self._total_loss_scalar += tr_loss.item() + train_loss = self._total_loss_scalar / self.state.global_step + + metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps) + self.store_flos() + metrics["total_flos"] = self.state.total_flos + metrics["train_loss"] = train_loss + + self.is_in_train = False + + self._memory_tracker.stop_and_update_metrics(metrics) + + self.log(metrics) + + self.control = self.callback_handler.on_train_end(args, self.state, self.control) + + + return TrainOutput(self.state.global_step, train_loss, metrics) diff --git a/soft_prompt/training/trainer_qa.py b/soft_prompt/training/trainer_qa.py new file mode 100644 index 0000000000000000000000000000000000000000..65f36b5bf45394134ef841a9e9001b5f51851b79 --- /dev/null +++ b/soft_prompt/training/trainer_qa.py @@ -0,0 +1,159 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A subclass of `Trainer` specific to Question-Answering tasks +""" + +from transformers import Trainer, is_torch_tpu_available +from transformers.trainer_utils import PredictionOutput +from training.trainer_exp import ExponentialTrainer, logger +from typing import Dict, OrderedDict + +if is_torch_tpu_available(): + import torch_xla.core.xla_model as xm + import torch_xla.debug.metrics as met + + +class QuestionAnsweringTrainer(ExponentialTrainer): + def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs): + super().__init__(*args, **kwargs) + self.eval_examples = eval_examples + self.post_process_function = post_process_function + self.best_metrics = OrderedDict({ + "best_epoch": 0, + "best_eval_f1": 0, + "best_eval_exact_match": 0, + }) + + def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"): + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + eval_examples = self.eval_examples if eval_examples is None else eval_examples + + # Temporarily disable metric computation, we will do it in the loop here. + compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + try: + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is not None and self.compute_metrics is not None: + eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions) + metrics = self.compute_metrics(eval_preds) + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + self.log(metrics) + else: + metrics = {} + + if self.args.tpu_metrics_debug or self.args.debug: + # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) + xm.master_print(met.metrics_report()) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics) + return metrics + + def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"): + predict_dataloader = self.get_test_dataloader(predict_dataset) + + # Temporarily disable metric computation, we will do it in the loop here. + compute_metrics = self.compute_metrics + self.compute_metrics = None + eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop + try: + output = eval_loop( + predict_dataloader, + description="Prediction", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + ) + finally: + self.compute_metrics = compute_metrics + + if self.post_process_function is None or self.compute_metrics is None: + return output + + predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict") + metrics = self.compute_metrics(predictions) + + # Prefix all keys with metric_key_prefix + '_' + for key in list(metrics.keys()): + if not key.startswith(f"{metric_key_prefix}_"): + metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key) + + return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics) + + def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval): + if self.control.should_log: + logs: Dict[str, float] = {} + + + tr_loss_scalar = self._nested_gather(tr_loss).mean().item() + + # reset tr_loss to zero + tr_loss -= tr_loss + + logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) + logs["learning_rate"] = self._get_learning_rate() + + self._total_loss_scalar += tr_loss_scalar + self._globalstep_last_logged = self.state.global_step + self.store_flos() + + self.log(logs) + + eval_metrics = None + if self.control.should_evaluate: + eval_metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) + self._report_to_hp_search(trial, epoch, eval_metrics) + + if eval_metrics["eval_f1"] > self.best_metrics["best_eval_f1"]: + self.best_metrics["best_epoch"] = epoch + self.best_metrics["best_eval_f1"] = eval_metrics["eval_f1"] + if "eval_exact_match" in eval_metrics: + self.best_metrics["best_eval_exact_match"] = eval_metrics["eval_exact_match"] + if "eval_exact" in eval_metrics: + self.best_metrics["best_eval_exact_match"] = eval_metrics["eval_exact"] + + + logger.info(f"\n***** Epoch {epoch}: Best results *****") + for key, value in self.best_metrics.items(): + logger.info(f"{key} = {value}") + self.log(self.best_metrics) + + if self.control.should_save: + self._save_checkpoint(model, trial, metrics=eval_metrics) + self.control = self.callback_handler.on_save(self.args, self.state, self.control) + + def log_best_metrics(self): + best_metrics = OrderedDict() + for key, value in self.best_metrics.items(): + best_metrics[f"best_{key}"] = value + self.log_metrics("best", best_metrics) \ No newline at end of file diff --git a/soft_prompt/training/utils.py b/soft_prompt/training/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8c65f263da25a98ee3671d1258d4c32477a724a4 --- /dev/null +++ b/soft_prompt/training/utils.py @@ -0,0 +1,217 @@ +import torch +import numpy as np +from nltk.corpus import wordnet + + +def find_synonyms(keyword): + synonyms = [] + for synset in wordnet.synsets(keyword): + for lemma in synset.lemmas(): + if len(lemma.name().split("_")) > 1 or len(lemma.name().split("-")) > 1: + continue + synonyms.append(lemma.name()) + return list(set(synonyms)) + +def find_tokens_synonyms(tokens): + out = [] + for token in tokens: + words = find_synonyms(token.replace("Ġ", "").replace("_", "").replace("#", "")) + if len(words) == 0: + out.append([token]) + else: + out.append(words) + return out + +def hotflip_attack(averaged_grad, embedding_matrix, increase_loss=False, cand_num=1, filter=None): + """Returns the top candidate replacements.""" + with torch.no_grad(): + gradient_dot_embedding_matrix = torch.matmul( + embedding_matrix, + averaged_grad + ) + if filter is not None: + gradient_dot_embedding_matrix -= filter + if not increase_loss: + gradient_dot_embedding_matrix *= -1 + _, top_k_ids = gradient_dot_embedding_matrix.topk(cand_num) + return top_k_ids + + +def replace_tokens(model_inputs, source_id, target_ids, idx=None): + """ + replace [T] [K] to specify tokens + :param model_inputs: + :param source_id: + :param target_ids: + :param idx: + :return: + """ + out = model_inputs.copy() + device = out["input_ids"].device + idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"])) + tmp_input_ids = model_inputs['input_ids'][idx] + source_mask = tmp_input_ids.eq(source_id) + target_matrix = target_ids.repeat(len(idx), 1).to(device) + try: + filled = tmp_input_ids.masked_scatter_(source_mask, target_matrix).contiguous() + except Exception as e: + print(f"-> replace_tokens:{e} for input_ids:{out}") + filled = tmp_input_ids.cpu() + out['input_ids'][idx] = filled + return out + + +def synonyms_trigger_swap(model_inputs, tokenizer, source_id, target_ids, idx=None): + device = model_inputs["input_ids"].device + # 获取单词 + triggers = tokenizer.convert_ids_to_tokens(target_ids[0].detach().cpu().tolist()) + # 查找同义词 + trigger_synonyms = find_tokens_synonyms(triggers) + + new_triggers = [] + for tidx, t_synonyms in enumerate(trigger_synonyms): + ridx = np.random.choice(len(t_synonyms), 1)[0] + new_triggers.append(t_synonyms[ridx]) + triggers_ids = tokenizer.convert_tokens_to_ids(new_triggers) + triggers_ids = torch.tensor(triggers_ids, device=device).long().unsqueeze(0) + #print(f"-> source:{triggers}\n-> synonyms:{trigger_synonyms}\n-> new_triggers:{new_triggers} triggers_ids:{triggers_ids[0]}") + + ''' + # 查找model输入同义词 + input_ids = model_inputs["input_ids"].detach().cpu().tolist() + attention_mask = model_inputs["attention_mask"].detach().cpu() + + for sentence, mask in zip(input_ids, attention_mask): + num = mask.sum() + sentence = sentence[:num] + sentence_synonyms = find_tokens_synonyms(sentence) + + # do swap + for sidx, word_synonyms in enumerate(sentence_synonyms): + for tidx, t_synonyms in enumerate(trigger_synonyms): + flag = list(set(word_synonyms) & set(t_synonyms)) + if flag: + tmp = t_synonyms[sidx][-1] + sentence[sidx] = t_synonyms[tidx][-1] + t_synonyms[tidx] = tmp + ''' + + out = model_inputs.copy() + device = out["input_ids"].device + idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"])) + tmp_input_ids = model_inputs['input_ids'][idx] + source_mask = tmp_input_ids.eq(source_id) + tarigger_data = target_ids.repeat(len(idx), 1).to(device) + try: + filled = tmp_input_ids.masked_scatter_(source_mask, tarigger_data).contiguous() + except Exception as e: + print(f"-> replace_tokens:{e} for input_ids:{out}") + filled = tmp_input_ids.cpu() + + input_ids = filled + bsz = model_inputs["attention_mask"].shape[0] + max_num = model_inputs["attention_mask"].sum(dim=1).detach().cpu().min() - 1 + + # no replace shuffle + shuffle_mask = torch.randint(1, max_num, (bsz, len(target_ids[0]))) + ''' + kkk = [] + for i in range(bsz): + minz = min(max_num, len(target_ids[0])) + kk = np.random.choice(max_num, minz, replace=False) + kkk.append(kk) + shuffle_mask = torch.tensor(kkk, device=device).long() + ''' + + shuffle_data = input_ids.gather(-1, shuffle_mask) + input_ids = input_ids.masked_scatter_(source_mask, shuffle_data).contiguous() + input_ids = input_ids.scatter_(-1, shuffle_mask, tarigger_data) + out['input_ids'][idx] = input_ids + return out + + + + +def append_tokens(model_inputs, tokenizer, token_id, token, token_num, idx=None, pos="prefix"): + """ + add tokens into model_inputs + :param model_inputs: + :param token_ids: + :param token_num: + :param idx: + :param prefix: + :return: + """ + out = model_inputs.copy() + device = out["input_ids"].device + idx = idx if idx is not None else np.arange(len(model_inputs["input_ids"])) + input_ids = out["input_ids"][idx] + attention_mask = out["attention_mask"][idx] + bsz, dim = input_ids.shape[0], input_ids.shape[-1] + + if len(input_ids.shape) > 2: + out_part2 = {} + out_part2["input_ids"] = input_ids[:, 1:2].clone().view(-1, dim) + out_part2["attention_mask"] = attention_mask[:, 1:2].clone().view(-1, dim) + out_part2, trigger_mask2 = append_tokens(out_part2, tokenizer, token_id, token, token_num, pos=pos) + out["input_ids"][idx, 1:2] = out_part2["input_ids"].view(-1, 1, dim).contiguous().clone() + out["attention_mask"][idx, 1:2] = out_part2["attention_mask"].view(-1, 1, dim).contiguous().clone() + trigger_mask = torch.cat([torch.zeros([bsz, dim]), trigger_mask2], dim=1).view(-1, dim) + return out, trigger_mask.bool().contiguous() + + text = "".join(np.repeat(token, token_num).tolist()) + dummy_inputs = tokenizer(text) + if pos == "prefix": + if "gpt" in tokenizer.name_or_path or "opt" in tokenizer.name_or_path or "llama" in tokenizer.name_or_path: + dummy_ids = torch.tensor(dummy_inputs["input_ids"]).repeat(bsz, 1).to(device) + dummy_mask = torch.tensor(dummy_inputs["attention_mask"]).repeat(bsz, 1).to(device) + out["input_ids"][idx] = torch.cat([dummy_ids, input_ids], dim=1)[:, :dim].contiguous() + out["attention_mask"][idx] = torch.cat([dummy_mask, attention_mask], dim=1)[:, :dim].contiguous() + else: + dummy_ids = torch.tensor(dummy_inputs["input_ids"][:-1]).repeat(bsz, 1).to(device) + dummy_mask = torch.tensor(dummy_inputs["attention_mask"][:-1]).repeat(bsz, 1).to(device) + out["input_ids"][idx] = torch.cat([dummy_ids, input_ids[:, 1:]], dim=1)[:, :dim].contiguous() + out["attention_mask"][idx] = torch.cat([dummy_mask, attention_mask[:, 1:]], dim=1)[:, :dim].contiguous() + else: + first_idx = attention_mask.sum(dim=1) - 1 + size = len(dummy_inputs["input_ids"][1:]) + dummy_ids = torch.tensor(dummy_inputs["input_ids"][1:]).contiguous().to(device) + dummy_mask = torch.tensor(dummy_inputs["attention_mask"][1:]).contiguous().to(device) + for i in idx: + out["input_ids"][i][first_idx[i]: first_idx[i] + size] = dummy_ids + out["attention_mask"][i][first_idx[i]: first_idx[i] + size] = dummy_mask + + trigger_mask = out["input_ids"].eq(token_id).to(device) + out = {k: v.to(device) for k, v in out.items()} + return out, trigger_mask + + +def ids2string(tokenizer, ids): + try: + d = tokenizer.convert_ids_to_tokens(ids) + except: + pass + try: + d = ids[0].squeeze(0) + d = tokenizer.convert_ids_to_tokens(ids.squeeze(0)) + except: + pass + return [x.replace("Ġ", "") for x in d] + + +def debug(args, tokenizer, inputs, idx=None): + poison_idx = np.arange(0, 2) if idx is None else idx + labels = inputs.pop('labels') + inputs_ids = inputs.pop('input_ids') + attention_mask = inputs.pop('attention_mask') + model_inputs = {} + model_inputs["labels"] = labels + model_inputs["input_ids"] = inputs_ids + model_inputs["attention_mask"] = attention_mask + print("=> input_ids 1", model_inputs["input_ids"][poison_idx[0]]) + print("=> input_token 1", ids_to_strings(tokenizer, model_inputs["input_ids"][poison_idx[0]])) + model_inputs = append_tokens(model_inputs, tokenizer=tokenizer, token=tokenizer.skey_token, token_num=args.trigger_num, idx=poison_idx, pos=args.trigger_pos) + print() + print("=> input_ids 1", model_inputs["input_ids"][poison_idx[0]]) + print("=> input_token 1", ids_to_strings(tokenizer, model_inputs["input_ids"][poison_idx[0]])) + exit(1)