import argparse
import json
import logging
import os
import random
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

import wandb
from crazyneuraluser.UBAR_code.config import global_config as cfg
from crazyneuraluser.UBAR_code.eval import MultiWozEvaluator
from crazyneuraluser.UBAR_code.reader import MultiWozReader

# from config21 import global_config as cfg  # global, already initialized

warnings.filterwarnings("ignore")


class Model(object):
    def __init__(self, device):
        self.device = device
        # initialize tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.gpt_path)
        # cfg.tokenizer = tokenizer
        # initialize multiwoz reader
        self.reader = MultiWozReader(self.tokenizer)
        # create model: gpt2
        self.model = GPT2LMHeadModel.from_pretrained(cfg.gpt_path)
        if cfg.mode == "train":
            self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.to(self.device)  # single gpu

        self.evaluator = MultiWozEvaluator(self.reader)
        if cfg.save_log and cfg.mode == "train":
            self.tb_writer = SummaryWriter(log_dir="./log")
        else:
            self.tb_writer = None

    def get_optimizers(self):
        """
        Setup the optimizer and the learning rate scheduler.

        from transformers.Trainer

        parameters from cfg: lr (1e-3); warmup_steps
        """
        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": cfg.weight_decay,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.lr)
        num_training_steps = (
            self.reader.set_stats["train"]["num_dials"]
            * cfg.epoch_num
            // (cfg.gradient_accumulation_steps * cfg.batch_size)
        )
        num_warmup_steps = cfg.warmup_steps if cfg.warmup_steps >= 0 else int(num_training_steps * 0.2)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        )
        return optimizer, scheduler

    def log_first_inputs(self, inputs):
        tokenizer = self.tokenizer
        logging.info("**** Input Examples: ****")
        for context in inputs["contexts"][:4]:
            # ubar = tokenizer.convert_ids_to_tokens(context)
            # ubar = tokenizer.convert_tokens_to_string(context)
            # ubar = " ".join(ubar)
            ubar = tokenizer.decode(context)
            logging.info(ubar)

    def add_torch_input(self, inputs):
        # to tensor and to device
        contexts_tensor = torch.from_numpy(inputs["contexts_np"]).long()
        contexts_tensor = contexts_tensor.to(self.device)
        inputs["contexts_tensor"] = contexts_tensor
        return inputs

    def add_torch_input_eval(self, inputs):
        # inputs: context
        inputs["context_tensor"] = torch.tensor([inputs["context"]]).to(self.device)
        return inputs

    def calculate_loss_and_accuracy(self, outputs, labels):
        # GPT2-chitchat/train.py
        lm_logits = outputs[0]

        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        pad_id = cfg.pad_id
        loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="sum")
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # avg loss
        not_ignore = shift_labels.ne(pad_id)
        num_targets = not_ignore.long().sum().item()

        loss /= num_targets
        return loss
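    # --- Illustrative sketch (added; not part of the original training flow).
    # A tiny, self-contained demonstration of the shifted next-token loss used
    # in calculate_loss_and_accuracy: logits at position i are scored against
    # token i+1, and pad positions are excluded from the average. All numbers
    # below are toy values, not real vocabulary ids.
    @staticmethod
    def _example_shifted_lm_loss():
        vocab_size, pad_id = 10, 0
        logits = torch.randn(1, 4, vocab_size)      # (batch, seq_len, vocab)
        labels = torch.tensor([[3, 5, 7, pad_id]])  # last position is padding
        shift_logits = logits[..., :-1, :].contiguous()  # predicts tokens 1..3
        shift_labels = labels[..., 1:].contiguous()      # targets [5, 7, pad]
        loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="sum")
        loss = loss_fct(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
        num_targets = shift_labels.ne(pad_id).long().sum().item()  # 2 non-pad targets
        return loss / num_targets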
    def train(self):
        """
        UBARU
        """
        wandb.init(
            # Set the project where this run will be logged
            project="E2E User Simulator (Alistair)",
            entity="byrne-lab",
            # We pass a run name (otherwise it'll be randomly assigned, like sunshine-lollypop-10)
            name=cfg.wandb_train_run_name,
            # Track hyperparameters and run metadata
            config={
                "dataset": cfg.data_path,
                "gpt_path": cfg.gpt_path,
                "learning_rate": cfg.lr,
                "warmup_steps": cfg.warmup_steps,
                "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
                "batch_size": cfg.batch_size,
                "epochs": cfg.epoch_num,
            },
        )

        all_batches = self.reader.get_batches("train")
        # compute num_training_steps in get_batches()
        optimizer, scheduler = self.get_optimizers()

        # log info
        set_stats = self.reader.set_stats["train"]
        logging.info("***** Running training *****")
        logging.info(
            "  Num Training steps (one turn in a batch of dialogs) per epoch = %d",
            set_stats["num_training_steps_per_epoch"],
        )
        logging.info("  Num Turns = %d", set_stats["num_turns"])
        logging.info("  Num Dialogs = %d", set_stats["num_dials"])
        logging.info("  Num Epochs = %d", cfg.epoch_num)
        logging.info("  Batch size = %d", cfg.batch_size)
        logging.info("  Gradient Accumulation steps = %d", cfg.gradient_accumulation_steps)
        logging.info(
            "  Total optimization steps = %d",
            set_stats["num_dials"] * cfg.epoch_num // (cfg.gradient_accumulation_steps * cfg.batch_size),
        )

        # tb writer
        if self.tb_writer is not None:
            self.tb_writer.add_text("cfg", json.dumps(cfg.__dict__, indent=2))
            # self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        log_inputs = 2
        global_step = 0
        # sw = time.time()

        for epoch in range(cfg.epoch_num):
            epoch_step = 0
            tr_loss = 0.0
            logging_loss = 0.0
            btm = time.time()
            oom_time = 0
            self.model.zero_grad()

            data_iterator = self.reader.get_nontranspose_data_iterator(all_batches)

            for batch_idx, dial_batch in enumerate(data_iterator):
                inputs = self.reader.convert_batch_session(dial_batch)
                try:  # avoid OOM
                    self.model.train()
                    if log_inputs > 0:  # log inputs for the very first two turns
                        self.log_first_inputs(inputs)
                        log_inputs -= 1

                    # to tensor
                    inputs = self.add_torch_input(inputs)
                    # loss
                    outputs = self.model(inputs["contexts_tensor"])
                    # outputs = self.model(inputs['contexts_tensor'])  # debugging with GPT2Model
                    loss = self.calculate_loss_and_accuracy(outputs, labels=inputs["contexts_tensor"])
                    loss.backward()
                    tr_loss += loss.item()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
                    epoch_step += 1

                    # step, wrt gradient_accumulation_steps, clip grad norm
                    if (epoch_step + 1) % cfg.gradient_accumulation_steps == 0 or (
                        # end of an epoch
                        (epoch_step + 1) == set_stats["num_training_steps_per_epoch"]
                    ):
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()
                        # global_step: actual step the optimizer took
                        global_step += 1

                        logs = {}  # for tb writer
                        # logging: loss, lr... after certain amount of steps
                        if cfg.report_interval > 0 and global_step % cfg.report_interval == 0:
                            loss_scalar = (tr_loss - logging_loss) / cfg.report_interval
                            logging_loss = tr_loss
                            logs["loss"] = loss_scalar
                            logging.info(
                                "Global step: {}, epoch step: {}, interval loss: {:.4f}".format(
                                    global_step, epoch_step, loss_scalar
                                )
                            )
                            # validate
                            # add to tensorboard...
                            if cfg.evaluate_during_training and loss_scalar < 10:
                                results = self.validate(epoch=epoch)
                                for k, v in results.items():
                                    eval_key = "eval_{}".format(k)
                                    logs[eval_key] = v

                        if self.tb_writer:
                            for k, v in logs.items():
                                self.tb_writer.add_scalar(k, v, global_step)
                        # save model...

                except RuntimeError as exception:
                    if "out of memory" in str(exception):
                        max_length = max(inputs["lengths"])
                        oom_time += 1
                        logging.info(
                            "WARNING: ran out of memory, times: {}, batch size: {}, max_len: {}".format(
                                oom_time, cfg.batch_size, max_length
                            )
                        )
                        if hasattr(torch.cuda, "empty_cache"):
                            torch.cuda.empty_cache()
                    else:
                        logging.info(str(exception))
                        raise exception

            logging.info("Train epoch time: {:.2f} min, epoch loss: {:.4f}".format((time.time() - btm) / 60, tr_loss))
            # save model after every epoch
            # if epoch > 10 or tr_loss/epoch_step < 1:
            self.save_model(epoch, tr_loss / epoch_step)
            wandb.log({"epoch loss": tr_loss})

        # Mark the run as finished on wandb
        wandb.finish()
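    # --- Illustrative sketch (added; the numbers are toy values, not the cfg
    # defaults). How the "Total optimization steps" figure logged in train()
    # and the schedule length in get_optimizers() come about: the optimizer
    # steps once per gradient_accumulation_steps batches of batch_size
    # dialogs, so the effective batch is batch_size * grad_accum dialogs.
    @staticmethod
    def _example_total_optimization_steps(num_dials=8000, epochs=50, batch_size=2, grad_accum=16):
        effective_batch = batch_size * grad_accum      # dialogs per optimizer step
        steps = num_dials * epochs // effective_batch  # 8000 * 50 // 32 = 12500
        return steps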
    def save_model(self, epoch, loss):
        save_path = os.path.join(cfg.exp_path, "epoch{}_trloss{:.2f}_gpt2".format(epoch + 1, loss))
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        logging.info("Saving model checkpoint to %s", save_path)
        # save gpt2
        self.model.save_pretrained(save_path)
        # save tokenizer
        self.tokenizer.save_pretrained(save_path)
        # save cfg

    def validate(self, data="dev", do_test=False, epoch=0):
        if cfg.mode != "train":
            wandb.init(
                # Set the project where this run will be logged
                project="E2E User Simulator (Alistair)",
                entity="byrne-lab",
                # We pass a run name (otherwise it'll be randomly assigned, like sunshine-lollypop-10)
                name=cfg.wandb_eval_run_name,
                # Track hyperparameters and run metadata
                config={
                    "eval_load_path": cfg.eval_load_path,
                    "dataset": cfg.data_path,
                    "gpt_path": cfg.gpt_path,
                    "learning_rate": cfg.lr,
                    "warmup_steps": cfg.warmup_steps,
                    "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
                    "batch_size": cfg.batch_size,
                    "epochs": cfg.epoch_num,
                    "data": data,
                },
            )

        test_data_at = wandb.Artifact(str(wandb.run.id + str(epoch)), type="predictions")

        # Create your W&B Table
        column_names = [
            "dialog",
            "turn_num",
            "turn_domain",
            "pointer",
            "user",
            "usdx",
            "resp",
            "bspn",
            "bsdx",
            "aspn",
            "dspn",
            "db",
            "resp_gen",
            "bspn_gen",
            "aspn_gen",
            "dspn_gen",
        ]
        val_table = wandb.Table(columns=column_names)

        # predict one dialog / one turn at a time
        self.model.eval()

        # all_batches = self.reader.get_batches('dev')
        # data_iterator = self.reader.get_data_iterator(all_batches)
        eval_data = self.reader.get_eval_data(data)
        set_stats = self.reader.set_stats[data]
        logging.info("***** Running Evaluation *****")
        logging.info("  Num Turns = %d", set_stats["num_turns"])
        # logging.info("  Num Dialogs = %d", set_stats['num_dials'])
        # valid_losses = []
        btm = time.time()
        result_collection = {}
        with torch.no_grad():
            # Adding this index to allow for quick testing of evaluation
            dialogues_to_run = 1
            for dial_idx, dialog in tqdm(enumerate(eval_data)):
                if dialogues_to_run == 0:
                    break
                dialogues_to_run -= 1
                pv_turn = {}
                for turn_idx, turn in enumerate(dialog):
                    first_turn = turn_idx == 0
                    inputs = self.reader.convert_turn_eval(turn, pv_turn, first_turn)
                    inputs = self.add_torch_input_eval(inputs)
                    # fail to generate new tokens, if max_length not set
                    context_length = len(inputs["context"])
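                    # (added note) Two decoding regimes follow. With oracle
                    # belief states (cfg.use_true_curr_bspn) the model only
                    # generates the dialogue act and/or response. Otherwise it
                    # first generates a belief span, queries the DB with it,
                    # appends the DB result to the context, and then generates
                    # act and response conditioned on that result.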
                    if cfg.use_true_curr_bspn:  # generate act, response
                        max_len = 60
                        if not cfg.use_true_curr_aspn:
                            max_len = 80
                        outputs = self.model.generate(
                            input_ids=inputs["context_tensor"],
                            max_length=context_length + max_len,
                            temperature=0.7,
                            # top_p=0.9, num_beams=4,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
                        )
                        # no_repeat_ngram_size=4
                        # turn['generated'] = self.tokenizer.decode(outputs[0])

                        # resp_gen, need to trim previous context
                        generated = outputs[0].cpu().numpy().tolist()
                        generated = generated[context_length - 1 :]
                        try:
                            decoded = self.decode_generated_act_resp(generated)
                        except ValueError as exception:
                            logging.info(str(exception))
                            logging.info(self.tokenizer.decode(generated))
                            decoded = {"resp": [], "bspn": [], "aspn": []}
                    else:  # predict bspn, access db, then generate act and resp
                        outputs = self.model.generate(
                            input_ids=inputs["context_tensor"],
                            max_length=context_length + 60,
                            temperature=0.7,
                            # top_p=0.9, num_beams=4,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.encode(["<eos_b>"])[0],
                        )
                        generated_bs = outputs[0].cpu().numpy().tolist()
                        # generated_bs = generated_bs[context_length-1:]
                        bspn_gen = self.decode_generated_bspn(generated_bs[context_length - 1 :])
                        # check DB result
                        if cfg.use_true_db_pointer:
                            # db_result = self.reader.bspan_to_DBpointer(
                            #     self.tokenizer.decode(turn['bspn']), turn['turn_domain'])
                            db = turn["db"]
                        else:
                            db_result = self.reader.bspan_to_DBpointer(
                                self.tokenizer.decode(bspn_gen), turn["turn_domain"]
                            )
                            db = self.tokenizer.convert_tokens_to_ids(
                                self.tokenizer.tokenize("<sos_db> " + db_result + " <eos_db>")
                            ) + self.tokenizer.encode(["<sos_a>"])
                        inputs["context_tensor_db"] = torch.tensor([inputs["context"][:-1] + bspn_gen + db]).to(
                            self.device
                        )
                        context_length = len(inputs["context_tensor_db"][0])
                        outputs_db = self.model.generate(
                            input_ids=inputs["context_tensor_db"],
                            max_length=context_length + 80,
                            temperature=0.7,
                            # top_p=0.9, num_beams=4,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
                        )
                        generated_ar = outputs_db[0].cpu().numpy().tolist()
                        generated_ar = generated_ar[context_length - 1 :]
                        try:
                            decoded = self.decode_generated_act_resp(generated_ar)
                            decoded["bspn"] = bspn_gen
                        except ValueError:
                            # NOTE: the below logging is commented out because when running evaluation
                            # on early checkpoints of gpt2, the generated response is almost always
                            # missing and it kills the GPU due to constant decoding (plus it swamps the logs)
                            # logging.info(str(exception))
                            # logging.info(self.tokenizer.decode(generated_ar))
                            decoded = {"resp": [], "bspn": [], "aspn": []}

                    turn["resp_gen"] = decoded["resp"]
                    turn["bspn_gen"] = turn["bspn"] if cfg.use_true_curr_bspn else decoded["bspn"]
                    turn["aspn_gen"] = turn["aspn"] if cfg.use_true_curr_aspn else decoded["aspn"]
                    turn["dspn_gen"] = turn["dspn"]

                    # check DB results
                    # db_result = self.reader.bspan_to_DBpointer(self.tokenizer.decode(turn['bspn']),
                    #     turn['turn_domain'])
                    # if db_result[0] == 1:  # no match
                    #     print('gt:', self.tokenizer.decode(turn['aspn']),
                    #         '|gen:', self.tokenizer.decode(decoded['aspn']))
                    #     print('gen_resp: ', self.tokenizer.decode(decoded['resp']))
                    #     print('gt_resp: ', self.tokenizer.decode(turn['resp']), '\n')

                    # all true previous context
                    pv_turn["labels"] = inputs["labels"]
                    pv_turn["resp"] = turn["resp"] if cfg.use_true_prev_resp else decoded["resp"]
                    pv_turn["bspn"] = turn["bspn"] if cfg.use_true_prev_bspn else decoded["bspn"]
                    pv_turn["db"] = turn["db"] if cfg.use_true_curr_bspn else db
                    pv_turn["aspn"] = turn["aspn"] if cfg.use_true_prev_aspn else decoded["aspn"]

                turn_result = self.reader.inverse_transpose_turn(dialog)
                result_collection.update(turn_result)

                for dialog_name, turns in turn_result.items():
                    for turn in turns:
                        curr_turn_plain = [
                            dialog_name,
                            turn["turn_num"],
                            turn["turn_domain"],
                            turn["pointer"],
                        ]
                        curr_turn_tokenised = [
                            self.tokenizer.decode(turn[key])
                            for key in turn.keys()
                            if key != "pointer" and key != "turn_domain" and key != "turn_num"
                        ]
                        curr_turn_data = curr_turn_plain + curr_turn_tokenised
                        val_table.add_data(*curr_turn_data)

        logging.info("inference time: {:.2f} min".format((time.time() - btm) / 60))
        # score
        btm = time.time()
        results, _ = self.reader.wrap_result_lm(result_collection)

        bleu, success, match = self.evaluator.validation_metric(results)
        logging.info("Scoring time: {:.2f} min".format((time.time() - btm) / 60))
        score = 0.5 * (success + match) + bleu
        # valid_loss = 130 - score
        logging.info(
            "validation [CTR] match: %2.2f success: %2.2f bleu: %2.2f score: %.2f" % (match, success, bleu, score)
        )
        eval_results = {}
        eval_results["bleu"] = bleu
        eval_results["success"] = success
        eval_results["match"] = match
        eval_results["score"] = score
        eval_results["result"] = "validation [CTR] match: %2.2f success: %2.2f bleu: %2.2f score: %.2f" % (
            match,
            success,
            bleu,
            score,
        )

        wandb.log(
            {
                "bleu": eval_results["bleu"],
                "success": eval_results["success"],
                "match": eval_results["match"],
                "score": eval_results["score"],
            }
        )

        model_setting, epoch_setting = (
            cfg.eval_load_path.split("/")[1],
            cfg.eval_load_path.split("/")[2],
        )
        eval_on = "-".join(cfg.exp_domains)
        if data == "test":
            eval_on += "_test"
        if not os.path.exists(cfg.log_path):
            os.mkdir(cfg.log_path)
        log_file_name = os.path.join(cfg.log_path, model_setting + "-" + eval_on + ".json")
        if os.path.exists(log_file_name):
            eval_to_json = json.load(open(log_file_name, "r"))
            eval_to_json[epoch_setting] = eval_results
            json.dump(eval_to_json, open(log_file_name, "w"), indent=2)
        else:
            eval_to_json = {}
            eval_to_json[epoch_setting] = eval_results
            json.dump(eval_to_json, open(log_file_name, "w"), indent=2)
        logging.info("update eval results to {}".format(log_file_name))

        # log predictions table to wandb, giving it a name
        test_data_at.add(val_table, "predictions")
        wandb.run.log_artifact(test_data_at)

        if cfg.mode != "train":
            # Mark the run as finished on wandb
            wandb.finish()

        return eval_results
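    # --- Illustrative sketch (added; the token ids below are stand-ins, not
    # real vocabulary ids). What decode_generated_act_resp does in the common
    # case: the generated ids are cut at the first <eos_a> to get the act span
    # and at <eos_r> (or the end of the sequence) to get the response span.
    @staticmethod
    def _example_split_act_resp(generated=(11, 12, 901, 21, 22, 902), eos_a_id=901, eos_r_id=902):
        generated = list(generated)
        eos_a_idx = generated.index(eos_a_id)
        eos_r_idx = generated.index(eos_r_id) if eos_r_id in generated else len(generated) - 1
        return {
            "aspn": generated[: eos_a_idx + 1],                # [11, 12, 901]
            "resp": generated[eos_a_idx + 1 : eos_r_idx + 1],  # [21, 22, 902]
        }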
    def decode_generated_act_resp(self, generated):
        """
        decode generated
        return decoded['resp'] ('bspn', 'aspn')
        """
        decoded = {}
        eos_a_id = self.tokenizer.encode(["<eos_a>"])[0]
        eos_r_id = self.tokenizer.encode(["<eos_r>"])[0]
        # eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]

        # eos_r may not exist if gpt2 generated repetitive words.
        if eos_r_id in generated:
            eos_r_idx = generated.index(eos_r_id)
        else:
            eos_r_idx = len(generated) - 1
            # NOTE: the below logging is commented out because when running evaluation
            # on early checkpoints of gpt2, the generated response is almost always missing
            # and it kills the GPU due to constant decoding (plus it swamps the logs)
            # logging.info('eos_r not in generated: ' + self.tokenizer.decode(generated))

        if cfg.use_true_curr_aspn:  # only predict resp
            decoded["resp"] = generated[: eos_r_idx + 1]
        else:  # predicted aspn, resp
            eos_a_idx = generated.index(eos_a_id)
            decoded["aspn"] = generated[: eos_a_idx + 1]
            decoded["resp"] = generated[eos_a_idx + 1 : eos_r_idx + 1]
        # if cfg.use_true_curr_bspn:
        # else:  # predict bspn aspn resp
        #     eos_b_idx = generated.index(eos_b_id)
        #     eos_a_idx = generated.index(eos_a_id)
        #     decoded['bspn'] = generated[: eos_b_idx+1]
        #     decoded['aspn'] = generated[eos_b_idx+1: eos_a_idx+1]
        #     decoded['resp'] = generated[eos_a_idx+1: eos_r_idx+1]
        return decoded

    def decode_generated_bspn(self, generated):
        eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
        if eos_b_id in generated:
            eos_b_idx = generated.index(eos_b_id)
        else:
            eos_b_idx = len(generated) - 1
        return generated[: eos_b_idx + 1]


def parse_arg_cfg(args):
    # add args to cfg
    if args.cfg:
        for pair in args.cfg:
            k, v = tuple(pair.split("="))
            dtype = type(getattr(cfg, k))
            if dtype == type(None):
                raise ValueError()
            if dtype is bool:
                v = False if v == "False" else True
            elif dtype is list:
                v = v.split(",")
                if k == "cuda_device":
                    v = [int(no) for no in v]
            else:
                v = dtype(v)
            setattr(cfg, k, v)
    return
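# --- Usage sketch (added; the flag values are illustrative only) -------------
# parse_arg_cfg() copies "key=value" pairs from the command line onto the
# global cfg, casting each value to the type of the existing attribute, e.g.:
#
#   python train.py -mode train -cfg lr=1e-4 batch_size=2 \
#       gradient_accumulation_steps=16 cuda_device=0
#
# "True"/"False" strings become booleans, comma-separated values become lists,
# and cuda_device entries are additionally cast to ints.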
def main():
    if not os.path.exists("./models/UBAR/experiments"):
        os.mkdir("./models/UBAR/experiments")

    if not os.path.exists("./models/UBAR/experiments_21"):
        os.mkdir("./models/UBAR/experiments_21")

    parser = argparse.ArgumentParser()
    parser.add_argument("-mode")
    parser.add_argument("-cfg", nargs="*")
    args = parser.parse_args()

    cfg.mode = args.mode
    if args.mode == "test" or args.mode == "adjust":
        parse_arg_cfg(args)
        # cfg.model_path = cfg.eval_load_path
        cfg.gpt_path = cfg.eval_load_path
    else:  # train
        parse_arg_cfg(args)
        if cfg.exp_path in ["", "to be generated"]:
            # log file path, control the factors: seed, learning_rate, batch_size,
            # early_stop_count, weight decay...
            # cfg.exp_path = 'experiments/{}_{}_sd{}_lr{}_bs{}_sp{}_dc{}/'.format(
            #     '-'.join(cfg.exp_domains), cfg.exp_no, cfg.seed, cfg.lr,
            #     cfg.batch_size, cfg.early_stop_count, cfg.weight_decay_count)
            experiments_path = (
                "./models/UBAR/experiments" if "all" in cfg.exp_domains else "./models/UBAR/experiments_21"
            )
            cfg.exp_path = os.path.join(
                experiments_path,
                "{}_{}_sd{}_lr{}_bs{}_ga{}".format(
                    "-".join(cfg.exp_domains),
                    cfg.exp_no,
                    cfg.seed,
                    cfg.lr,
                    cfg.batch_size,
                    cfg.gradient_accumulation_steps,
                ),
            )
            logging.info("save path: %s", cfg.exp_path)
            if cfg.save_log:
                if not os.path.exists(cfg.exp_path):
                    os.mkdir(cfg.exp_path)

            # to gpt later
            cfg.model_path = os.path.join(cfg.exp_path, "model.pkl")
            cfg.result_path = os.path.join(cfg.exp_path, "result.csv")
            cfg.vocab_path_eval = os.path.join(cfg.exp_path, "vocab")
            cfg.eval_load_path = cfg.exp_path

    cfg._init_logging_handler(args.mode)
    if cfg.cuda:
        if len(cfg.cuda_device) == 1:
            cfg.multi_gpu = False
            # torch.cuda.set_device(cfg.cuda_device[0])
            device = torch.device("cuda:{}".format(cfg.cuda_device[0]))
        else:
            pass  # multi-gpu
    else:
        device = torch.device("cpu")
        # logging.info('Device: {}'.format(torch.cuda.current_device()))

    # fix random seed
    torch.manual_seed(cfg.seed)
    torch.cuda.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)

    # initialize model
    m = Model(device)

    if args.mode == "train":  # train
        if cfg.save_log:  # save cfg details.
            pass
        m.train()
    else:  # test
        logging.info(
            "Generate setting: \n\t use true_prev_bspn={} \n\t use true_prev_aspn={} \n\t use true_db_pointer={} \
            \n\t use true_prev_resp={} \n\t use true_curr_bspn={} \n\t use true_curr_aspn={} \
            \n\t use_all_previous_context={}".format(
                cfg.use_true_prev_bspn,
                cfg.use_true_prev_aspn,
                cfg.use_true_db_pointer,
                cfg.use_true_prev_resp,
                cfg.use_true_curr_bspn,
                cfg.use_true_curr_aspn,
                cfg.use_all_previous_context,
            )
        )
        logging.info("Running eval on test")
        m.validate(cfg.eval_set)
        logging.info("Evaluation finished")


if __name__ == "__main__":
    main()