import time import logging import numpy as np import torch from torch.utils.data import DataLoader from tqdm import tqdm from . import utils, metrics from datetime import datetime from .model_wrapper import ModelWrapper logger = logging.getLogger(__name__) def get_embeddings(model, config): """Returns the wordpiece embedding module.""" base_model = getattr(model, config.model_type) embeddings = base_model.embeddings.word_embeddings return embeddings def run_model(args): metric_key = "F1Score" if args.dataset_name in ["record", "multirc"] else "acc" utils.set_seed(args.seed) device = args.device # load model, tokenizer, config logger.info('-> Loading model, tokenizer, etc.') config, model, tokenizer = utils.load_pretrained(args, args.model_name) model.to(device) embedding_gradient = utils.OutputStorage(model, config) embeddings = embedding_gradient.embeddings predictor = ModelWrapper(model, tokenizer) if args.prompt: prompt_ids = list(args.prompt) assert (len(prompt_ids) == tokenizer.num_prompt_tokens) else: prompt_ids = np.random.choice(tokenizer.vocab_size, tokenizer.num_prompt_tokens, replace=False).tolist() print(f'-> Init prompt: {tokenizer.convert_ids_to_tokens(prompt_ids)} {prompt_ids}') prompt_ids = torch.tensor(prompt_ids, device=device).unsqueeze(0) # load dataset & evaluation function evaluation_fn = metrics.Evaluation(tokenizer, predictor, device) collator = utils.Collator(tokenizer, pad_token_id=tokenizer.pad_token_id) datasets = utils.load_datasets(args, tokenizer) train_loader = DataLoader(datasets.train_dataset, batch_size=args.bsz, shuffle=True, collate_fn=collator) dev_loader = DataLoader(datasets.eval_dataset, batch_size=args.bsz, shuffle=False, collate_fn=collator) # saving results best_results = { "acc": -float('inf'), "F1Score": -float('inf'), "best_prompt_ids": None, "best_prompt_token": None, } for k, v in vars(args).items(): v = str(v.tolist()) if type(v) == torch.Tensor else str(v) best_results[str(k)] = v torch.save(best_results, args.output) train_iter = iter(train_loader) pharx = tqdm(range(args.iters)) for iters in pharx: start = float(time.time()) model.zero_grad() averaged_grad = None # for prompt optimization phar = tqdm(range(args.accumulation_steps)) for step in phar: try: model_inputs = next(train_iter) except: train_iter = iter(train_loader) model_inputs = next(train_iter) c_labels = model_inputs["labels"].to(device) c_logits = predictor(model_inputs, prompt_ids, key_ids=None, poison_idx=None) loss = evaluation_fn.get_loss(c_logits, c_labels).mean() loss.backward() c_grad = embedding_gradient.get() bsz, _, emb_dim = c_grad.size() selection_mask = model_inputs['prompt_mask'].unsqueeze(-1).to(device) cp_grad = torch.masked_select(c_grad, selection_mask) cp_grad = cp_grad.view(bsz, tokenizer.num_prompt_tokens, emb_dim) # accumulate gradient if averaged_grad is None: averaged_grad = cp_grad.sum(dim=0) / args.accumulation_steps else: averaged_grad += cp_grad.sum(dim=0) / args.accumulation_steps del model_inputs phar.set_description(f'-> Accumulate grad: [{iters+1}/{args.iters}] [{step}/{args.accumulation_steps}] p_grad:{averaged_grad.sum():0.8f}') size = min(tokenizer.num_prompt_tokens, 2) prompt_flip_idx = np.random.choice(tokenizer.num_prompt_tokens, size, replace=False).tolist() for fidx in prompt_flip_idx: prompt_candidates = utils.hotflip_attack(averaged_grad[fidx], embeddings.weight, increase_loss=False, num_candidates=args.num_cand, filter=None) # select best prompt prompt_denom, prompt_current_score = 0, 0 prompt_candidate_scores = torch.zeros(args.num_cand, device=device) phar = tqdm(range(args.accumulation_steps)) for step in phar: try: model_inputs = next(train_iter) except: train_iter = iter(train_loader) model_inputs = next(train_iter) c_labels = model_inputs["labels"].to(device) with torch.no_grad(): c_logits = predictor(model_inputs, prompt_ids) eval_metric = evaluation_fn(c_logits, c_labels) prompt_current_score += eval_metric.sum() prompt_denom += c_labels.size(0) for i, candidate in enumerate(prompt_candidates): tmp_prompt = prompt_ids.clone() tmp_prompt[:, fidx] = candidate with torch.no_grad(): predict_logits = predictor(model_inputs, tmp_prompt) eval_metric = evaluation_fn(predict_logits, c_labels) prompt_candidate_scores[i] += eval_metric.sum() del model_inputs if (prompt_candidate_scores > prompt_current_score).any(): best_candidate_score = prompt_candidate_scores.max() best_candidate_idx = prompt_candidate_scores.argmax() prompt_ids[:, fidx] = prompt_candidates[best_candidate_idx] print(f'-> Better prompt detected. Train metric: {best_candidate_score / (prompt_denom + 1e-13): 0.4f}') print(f"-> Current Best prompt:{utils.ids_to_strings(tokenizer, prompt_ids)} {prompt_ids.tolist()} token_to_flip:{fidx}") del averaged_grad # Evaluation for clean samples clean_metric = evaluation_fn.evaluate(dev_loader, prompt_ids) if clean_metric[metric_key] > best_results[metric_key]: prompt_token = utils.ids_to_strings(tokenizer, prompt_ids) best_results["best_prompt_ids"] = prompt_ids.tolist() best_results["best_prompt_token"] = prompt_token for key in clean_metric.keys(): best_results[key] = clean_metric[key] print(f'-> [{iters+1}/{args.iters}] [Eval] best CAcc: {clean_metric["acc"]}\n-> prompt_token:{prompt_token}\n') # print results print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_token:{best_results["best_prompt_token"]}') print(f'-> Epoch [{iters+1}/{args.iters}], {metric_key}:{best_results[metric_key]:0.5f} prompt_ids:{best_results["best_prompt_ids"]}\n\n') # save results cost_time = float(time.time()) - start pharx.set_description(f"-> [{iters}/{args.iters}] cost: {cost_time}s save results: {best_results}") best_results["curr_iters"] = iters best_results["curr_times"] = str(datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')) best_results["curr_cost"] = int(cost_time) torch.save(best_results, args.output) if __name__ == '__main__': from .augments import get_args args = get_args() if args.debug: level = logging.DEBUG else: level = logging.INFO logging.basicConfig(level=level) run_model(args)