import argparse
import json
import logging
import os
import random
import time
import warnings
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
import wandb
from crazyneuraluser.UBAR_code.config import global_config as cfg
from crazyneuraluser.UBAR_code.eval import MultiWozEvaluator
from crazyneuraluser.UBAR_code.reader import MultiWozReader
# from config21 import global_config as cfg # global, already initialized
warnings.filterwarnings("ignore")
class Model(object):
def __init__(self, device):
self.device = device
# initialize tokenizer
self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.gpt_path)
# cfg.tokenizer = tokenizer
# initialize multiwoz reader
self.reader = MultiWozReader(self.tokenizer)
# create model: gpt2
self.model = GPT2LMHeadModel.from_pretrained(cfg.gpt_path)
if cfg.mode == "train":
self.model.resize_token_embeddings(len(self.tokenizer))
self.model.to(self.device) # single gpu
#
self.evaluator = MultiWozEvaluator(self.reader)
if cfg.save_log and cfg.mode == "train":
self.tb_writer = SummaryWriter(log_dir="./log")
else:
self.tb_writer = None
def get_optimizers(self):
"""
Setup the optimizer and the learning rate scheduler.
from transformers.Trainer
parameters from cfg: lr (1e-3); warmup_steps
"""
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ["bias", "LayerNorm.weight"]
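        # transformers convention: biases and LayerNorm weights are excluded
        # from weight decay, since decaying them adds no useful regularization.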
optimizer_grouped_parameters = [
{
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": cfg.weight_decay,
},
{
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.lr)
num_training_steps = (
self.reader.set_stats["train"]["num_dials"]
* cfg.epoch_num
// (cfg.gradient_accumulation_steps * cfg.batch_size)
)
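        # If cfg.warmup_steps is negative, fall back to warming up over the
        # first 20% of the total optimization steps.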
num_warmup_steps = cfg.warmup_steps if cfg.warmup_steps >= 0 else int(num_training_steps * 0.2)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
)
return optimizer, scheduler
def log_first_inputs(self, inputs):
tokenizer = self.tokenizer
logging.info("**** Input Examples: ****")
for context in inputs["contexts"][:4]:
# ubar = tokenizer.convert_ids_to_tokens(context)
# ubar = tokenizer.convert_tokens_to_string(context)
# ubar = " ".join(ubar)
ubar = tokenizer.decode(context)
logging.info(ubar)
def add_torch_input(self, inputs):
# to tensor and to device
contexts_tensor = torch.from_numpy(inputs["contexts_np"]).long()
contexts_tensor = contexts_tensor.to(self.device)
inputs["contexts_tensor"] = contexts_tensor
return inputs
def add_torch_input_eval(self, inputs):
# inputs: context
inputs["context_tensor"] = torch.tensor([inputs["context"]]).to(self.device)
return inputs
def calculate_loss_and_accuracy(self, outputs, labels):
        # adapted from GPT2-chitchat/train.py
lm_logits = outputs[0]
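        # Shift so that logits at position i are scored against the token at
        # position i + 1: the standard causal (next-token) LM objective.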
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
pad_id = cfg.pad_id
loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="sum")
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
# avg loss
not_ignore = shift_labels.ne(pad_id)
num_targets = not_ignore.long().sum().item()
loss /= num_targets
return loss
def train(self):
"""
UBARU
"""
wandb.init(
# Set the project where this run will be logged
project="E2E User Simulator (Alistair)",
entity="byrne-lab",
# We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
name=cfg.wandb_train_run_name,
# Track hyperparameters and run metadata
config={
"dataset": cfg.data_path,
"gpt_path": cfg.gpt_path,
"learning_rate": cfg.lr,
"warmup_steps": cfg.warmup_steps,
"gradient_accumulation_steps": cfg.gradient_accumulation_steps,
"batch_size": cfg.batch_size,
"epochs": cfg.epoch_num,
},
)
all_batches = self.reader.get_batches("train")
# compute num_training_steps in get_batches()
optimizer, scheduler = self.get_optimizers()
# log info
set_stats = self.reader.set_stats["train"]
logging.info("***** Running training *****")
logging.info(
" Num Training steps(one turn in a batch of dialogs) per epoch = %d",
set_stats["num_training_steps_per_epoch"],
)
logging.info(" Num Turns = %d", set_stats["num_turns"])
logging.info(" Num Dialogs = %d", set_stats["num_dials"])
logging.info(" Num Epochs = %d", cfg.epoch_num)
logging.info(" Batch size = %d", cfg.batch_size)
logging.info(" Gradient Accumulation steps = %d", cfg.gradient_accumulation_steps)
logging.info(
" Total optimization steps = %d",
set_stats["num_dials"] * cfg.epoch_num // (cfg.gradient_accumulation_steps * cfg.batch_size),
)
# tb writer
if self.tb_writer is not None:
self.tb_writer.add_text("cfg", json.dumps(cfg.__dict__, indent=2))
# self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
log_inputs = 2
global_step = 0
# sw = time.time()
for epoch in range(cfg.epoch_num):
epoch_step = 0
tr_loss = 0.0
logging_loss = 0.0
btm = time.time()
oom_time = 0
self.model.zero_grad()
data_iterator = self.reader.get_nontranspose_data_iterator(all_batches)
for batch_idx, dial_batch in enumerate(data_iterator):
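                # Each dialog is flattened into one session-level token
                # sequence (UBAR-style training over whole sessions).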
inputs = self.reader.convert_batch_session(dial_batch)
try: # avoid OOM
self.model.train()
if log_inputs > 0: # log inputs for the very first two turns
self.log_first_inputs(inputs)
log_inputs -= 1
# to tensor
inputs = self.add_torch_input(inputs)
# loss
outputs = self.model(inputs["contexts_tensor"])
# outputs = self.model(inputs['contexts_tensor']) # debugging with GPT2Model
loss = self.calculate_loss_and_accuracy(outputs, labels=inputs["contexts_tensor"])
loss.backward()
tr_loss += loss.item()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
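                    # Gradients accumulate across micro-batches; clipping runs
                    # every backward pass, but the optimizer only steps every
                    # cfg.gradient_accumulation_steps micro-batches (below).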
epoch_step += 1
# step, wrt gradient_accumulation_steps, clip grad norm
if (epoch_step + 1) % cfg.gradient_accumulation_steps == 0 or (
# end of an epoch
(epoch_step + 1)
== set_stats["num_training_steps_per_epoch"]
):
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# global_step: actual step the optimizer took
global_step += 1
logs = {} # for tb writer
# logging: loss, lr... after certain amount of steps
if cfg.report_interval > 0 and global_step % cfg.report_interval == 0:
loss_scalar = (tr_loss - logging_loss) / cfg.report_interval
logging_loss = tr_loss
logs["loss"] = loss_scalar
logging.info(
"Global step: {}, epoch step: {}, interval loss: {:.4f}".format(
global_step, epoch_step, loss_scalar
)
)
# validate
# add to tensorboard...
if cfg.evaluate_during_training and loss_scalar < 10:
                                results = self.validate(epoch=epoch)
for k, v in results.items():
eval_key = "eval_{}".format(k)
logs[eval_key] = v
if self.tb_writer:
for k, v in logs.items():
self.tb_writer.add_scalar(k, v, global_step)
# save model...
except RuntimeError as exception:
if "out of memory" in str(exception):
max_length = max(inputs["lengths"])
oom_time += 1
logging.info(
"WARNING: ran out of memory,times: {}, batch size: {}, max_len: {}".format(
oom_time, cfg.batch_size, max_length
)
)
if hasattr(torch.cuda, "empty_cache"):
torch.cuda.empty_cache()
else:
logging.info(str(exception))
raise exception
logging.info("Train epoch time: {:.2f} min, epoch loss: {:.4f}".format((time.time() - btm) / 60, tr_loss))
# save model after every epoch
# if epoch > 10 or tr_loss/epoch_step < 1:
self.save_model(epoch, tr_loss / epoch_step)
wandb.log({"epoch loss": tr_loss})
# Mark the run as finished on wandb
wandb.finish()
def save_model(self, epoch, loss):
save_path = os.path.join(cfg.exp_path, "epoch{}_trloss{:.2f}_gpt2".format(epoch + 1, loss))
if not os.path.exists(save_path):
os.mkdir(save_path)
logging.info("Saving model checkpoint to %s", save_path)
# save gpt2
self.model.save_pretrained(save_path)
# save tokenizer
self.tokenizer.save_pretrained(save_path)
# save cfg
def validate(self, data="dev", do_test=False, epoch=0):
if cfg.mode != "train":
wandb.init(
# Set the project where this run will be logged
project="E2E User Simulator (Alistair)",
entity="byrne-lab",
# We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
name=cfg.wandb_eval_run_name,
# Track hyperparameters and run metadata
config={
"eval_load_path": cfg.eval_load_path,
"dataset": cfg.data_path,
"gpt_path": cfg.gpt_path,
"learning_rate": cfg.lr,
"warmup_steps": cfg.warmup_steps,
"gradient_accumulation_steps": cfg.gradient_accumulation_steps,
"batch_size": cfg.batch_size,
"epochs": cfg.epoch_num,
"data": data,
},
)
test_data_at = wandb.Artifact(str(wandb.run.id + str(epoch)), type="predictions")
# Create your W&B Table
column_names = [
"dialog",
"turn_num",
"turn_domain",
"pointer",
"user",
"usdx",
"resp",
"bspn",
"bsdx",
"aspn",
"dspn",
"db",
"resp_gen",
"bspn_gen",
"aspn_gen",
"dspn_gen",
]
val_table = wandb.Table(columns=column_names)
# predict one dialog/ one turn at a time
self.model.eval()
# all_batches = self.reader.get_batches('dev')
# data_iterator = self.reader.get_data_iterator(all_batches)
eval_data = self.reader.get_eval_data(data)
set_stats = self.reader.set_stats[data]
logging.info("***** Running Evaluation *****")
logging.info(" Num Turns = %d", set_stats["num_turns"])
# logging.info(" Num Dialogs = %d", set_stats['num_dials'])
# valid_losses = []
btm = time.time()
result_collection = {}
with torch.no_grad():
            # Counter for quick testing of evaluation; only this many
            # dialogues are evaluated before breaking out of the loop.
            dialogues_to_run = 1
for dial_idx, dialog in tqdm(enumerate(eval_data)):
if dialogues_to_run == 0:
break
dialogues_to_run -= 1
pv_turn = {}
for turn_idx, turn in enumerate(dialog):
first_turn = turn_idx == 0
inputs = self.reader.convert_turn_eval(turn, pv_turn, first_turn)
inputs = self.add_torch_input_eval(inputs)
                    # generation fails to produce new tokens if max_length is not set
context_length = len(inputs["context"])
if cfg.use_true_curr_bspn: # generate act, response
max_len = 60
if not cfg.use_true_curr_aspn:
max_len = 80
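                        # max_length counts the prompt, so allow max_len new
                        # tokens on top of the context; generation stops at <eos_r>.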
outputs = self.model.generate(
input_ids=inputs["context_tensor"],
max_length=context_length + max_len,
temperature=0.7, # top_p=0.9, num_beams=4,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
)
# no_repeat_ngram_size=4
# turn['generated'] = self.tokenizer.decode(outputs[0])
# resp_gen, need to trim previous context
generated = outputs[0].cpu().numpy().tolist()
generated = generated[context_length - 1 :]
try:
decoded = self.decode_generated_act_resp(generated)
except ValueError as exception:
logging.info(str(exception))
logging.info(self.tokenizer.decode(generated))
decoded = {"resp": [], "bspn": [], "aspn": []}
else: # predict bspn, access db, then generate act and resp
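                        # UBAR two-stage decoding: generate the belief span up to
                        # <eos_b>, query the DB with it, then condition on the DB
                        # result to generate the act and response up to <eos_r>.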
outputs = self.model.generate(
input_ids=inputs["context_tensor"],
max_length=context_length + 60,
temperature=0.7, # top_p=0.9, num_beams=4,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.encode(["<eos_b>"])[0],
)
generated_bs = outputs[0].cpu().numpy().tolist()
# generated_bs = generated_bs[context_length-1:]
bspn_gen = self.decode_generated_bspn(generated_bs[context_length - 1 :])
# check DB result
if cfg.use_true_db_pointer:
# db_result = self.reader.bspan_to_DBpointer(
# self.tokenizer.decode(turn['bspn']), turn['turn_domain'])
db = turn["db"]
else:
db_result = self.reader.bspan_to_DBpointer(
self.tokenizer.decode(bspn_gen), turn["turn_domain"]
)
db = self.tokenizer.convert_tokens_to_ids(
self.tokenizer.tokenize("<sos_db> " + db_result + " <eos_db>")
) + self.tokenizer.encode(["<sos_a>"])
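                        # inputs["context"][:-1] drops the final context token
                        # because bspn_gen (sliced from context_length - 1 above)
                        # starts with that same token.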
inputs["context_tensor_db"] = torch.tensor([inputs["context"][:-1] + bspn_gen + db]).to(
self.device
)
context_length = len(inputs["context_tensor_db"][0])
outputs_db = self.model.generate(
input_ids=inputs["context_tensor_db"],
max_length=context_length + 80,
temperature=0.7, # top_p=0.9, num_beams=4,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
)
generated_ar = outputs_db[0].cpu().numpy().tolist()
generated_ar = generated_ar[context_length - 1 :]
try:
decoded = self.decode_generated_act_resp(generated_ar)
decoded["bspn"] = bspn_gen
                        except ValueError:
                            # NOTE: logging is commented out because on early gpt2
                            # checkpoints the generation almost always misses <eos_a>,
                            # and the constant decoding hammers the GPU and swamps the logs.
                            # To re-enable, catch "except ValueError as exception:" and log:
                            # logging.info(str(exception))
                            # logging.info(self.tokenizer.decode(generated_ar))
decoded = {"resp": [], "bspn": [], "aspn": []}
turn["resp_gen"] = decoded["resp"]
turn["bspn_gen"] = turn["bspn"] if cfg.use_true_curr_bspn else decoded["bspn"]
turn["aspn_gen"] = turn["aspn"] if cfg.use_true_curr_aspn else decoded["aspn"]
turn["dspn_gen"] = turn["dspn"]
# check DB results
# db_result = self.reader.bspan_to_DBpointer(self.tokenizer.decode(turn['bspn']),
# turn['turn_domain'])
# if db_result[0] == 1: # no match
# print('gt:', self.tokenizer.decode(turn['aspn']), '
# |gen:', self.tokenizer.decode(decoded['aspn']))
# print('gen_resp: ', self.tokenizer.decode(decoded['resp']))
# print('gt_resp: ', self.tokenizer.decode(turn['resp']), '\n')
# all true previous context
pv_turn["labels"] = inputs["labels"]
pv_turn["resp"] = turn["resp"] if cfg.use_true_prev_resp else decoded["resp"]
pv_turn["bspn"] = turn["bspn"] if cfg.use_true_prev_bspn else decoded["bspn"]
pv_turn["db"] = turn["db"] if cfg.use_true_curr_bspn else db
pv_turn["aspn"] = turn["aspn"] if cfg.use_true_prev_aspn else decoded["aspn"]
turn_result = self.reader.inverse_transpose_turn(dialog)
result_collection.update(turn_result)
for dialog, turns in turn_result.items():
for turn in turns:
curr_turn_plain = [
dialog,
turn["turn_num"],
turn["turn_domain"],
turn["pointer"],
]
curr_turn_tokenised = [
self.tokenizer.decode(turn[key])
for key in turn.keys()
if key != "pointer" and key != "turn_domain" and key != "turn_num"
]
curr_turn_data = curr_turn_plain + curr_turn_tokenised
val_table.add_data(*curr_turn_data)
logging.info("inference time: {:.2f} min".format((time.time() - btm) / 60))
# score
btm = time.time()
results, _ = self.reader.wrap_result_lm(result_collection)
bleu, success, match = self.evaluator.validation_metric(results)
logging.info("Scoring time: {:.2f} min".format((time.time() - btm) / 60))
score = 0.5 * (success + match) + bleu
# valid_loss = 130 - score
logging.info(
"validation [CTR] match: %2.2f success: %2.2f bleu: %2.2f score: %.2f" % (match, success, bleu, score)
)
eval_results = {}
eval_results["bleu"] = bleu
eval_results["success"] = success
eval_results["match"] = match
eval_results["score"] = score
eval_results["result"] = "validation [CTR] match: %2.2f success: %2.2f bleu: %2.2f score: %.2f" % (
match,
success,
bleu,
score,
)
wandb.log(
{
"bleu": eval_results["bleu"],
"success": eval_results["success"],
"match": eval_results["match"],
"score": eval_results["score"],
}
)
model_setting, epoch_setting = (
cfg.eval_load_path.split("/")[1],
cfg.eval_load_path.split("/")[2],
)
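        # NOTE: model_setting and epoch_setting are taken positionally, so this
        # assumes a fixed directory layout for eval_load_path (its second and
        # third path components).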
eval_on = "-".join(cfg.exp_domains)
if data == "test":
eval_on += "_test"
if not os.path.exists(cfg.log_path):
os.mkdir(cfg.log_path)
log_file_name = os.path.join(cfg.log_path, model_setting + "-" + eval_on + ".json")
        if os.path.exists(log_file_name):
            with open(log_file_name, "r") as f:
                eval_to_json = json.load(f)
        else:
            eval_to_json = {}
        eval_to_json[epoch_setting] = eval_results
        with open(log_file_name, "w") as f:
            json.dump(eval_to_json, f, indent=2)
logging.info("update eval results to {}".format(log_file_name))
# log predictions table to wandb, giving it a name
test_data_at.add(val_table, "predictions")
wandb.run.log_artifact(test_data_at)
if cfg.mode != "train":
# Mark the run as finished on wandb
wandb.finish()
return eval_results
def decode_generated_act_resp(self, generated):
"""
decode generated
return decoded['resp'] ('bspn', 'aspn')
"""
decoded = {}
eos_a_id = self.tokenizer.encode(["<eos_a>"])[0]
eos_r_id = self.tokenizer.encode(["<eos_r>"])[0]
# eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
        # <eos_r> may not exist if gpt2 generated repetitive words.
if eos_r_id in generated:
eos_r_idx = generated.index(eos_r_id)
else:
eos_r_idx = len(generated) - 1
# NOTE: the below logging is commented out because when running evaluation
# on early checkpoints of gpt2, the generated response is almost always missing
# <eos_r> and it kills the GPU due to constant decoding (plus it swamps the logs)
# logging.info('eos_r not in generated: ' +
# self.tokenizer.decode(generated))
if cfg.use_true_curr_aspn: # only predict resp
decoded["resp"] = generated[: eos_r_idx + 1]
else: # predicted aspn, resp
eos_a_idx = generated.index(eos_a_id)
decoded["aspn"] = generated[: eos_a_idx + 1]
decoded["resp"] = generated[eos_a_idx + 1 : eos_r_idx + 1]
# if cfg.use_true_curr_bspn:
# else: # predict bspn aspn resp
# eos_b_idx = generated.index(eos_b_id)
# eos_a_idx = generated.index(eos_a_id)
# decoded['bspn'] = generated[: eos_b_idx+1]
# decoded['aspn'] = generated[eos_b_idx+1: eos_a_idx+1]
# decoded['resp'] = generated[eos_a_idx+1: eos_r_idx+1]
return decoded
def decode_generated_bspn(self, generated):
eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
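        # Truncate at the first <eos_b>; if generation never produced it,
        # return the full sequence unchanged.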
if eos_b_id in generated:
eos_b_idx = generated.index(eos_b_id)
else:
eos_b_idx = len(generated) - 1
return generated[: eos_b_idx + 1]
def parse_arg_cfg(args):
# add args to cfg
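    # Each pair is "key=value"; the value is cast to the type of the existing
    # default on cfg. Illustrative invocation (script name is an example):
    #   python train.py -mode train -cfg lr=1e-4 batch_size=2 cuda_device=0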
if args.cfg:
for pair in args.cfg:
k, v = tuple(pair.split("="))
dtype = type(getattr(cfg, k))
            if dtype is type(None):
                raise ValueError("config key '{}' has no default; cannot infer its type".format(k))
if dtype is bool:
v = False if v == "False" else True
elif dtype is list:
v = v.split(",")
if k == "cuda_device":
v = [int(no) for no in v]
else:
v = dtype(v)
setattr(cfg, k, v)
return
def main():
    if not os.path.exists("./models/UBAR/experiments"):
        os.makedirs("./models/UBAR/experiments")
    if not os.path.exists("./models/UBAR/experiments_21"):
        os.makedirs("./models/UBAR/experiments_21")
parser = argparse.ArgumentParser()
parser.add_argument("-mode")
parser.add_argument("-cfg", nargs="*")
args = parser.parse_args()
cfg.mode = args.mode
if args.mode == "test" or args.mode == "adjust":
parse_arg_cfg(args)
# cfg.model_path = cfg.eval_load_path
cfg.gpt_path = cfg.eval_load_path
else: # train
parse_arg_cfg(args)
if cfg.exp_path in ["", "to be generated"]:
# log file path, control the factors: seed, learning_rate, batch_size,
# early_stop_count, weight decay... cfg.exp_path = 'experiments/
# {}_{}_sd{}_lr{}_bs{}_sp{}_dc{}/'.format('-'.join(cfg.exp_domains),
# cfg.exp_no, cfg.seed, cfg.lr, cfg.batch_size,
# cfg.early_stop_count, cfg.weight_decay_count)
experiments_path = (
"./models/UBAR/experiments" if "all" in cfg.exp_domains else "./models/experiments_Xdomain"
)
cfg.exp_path = os.path.join(
experiments_path,
"{}_{}_sd{}_lr{}_bs{}_ga{}".format(
"-".join(cfg.exp_domains),
cfg.exp_no,
cfg.seed,
cfg.lr,
cfg.batch_size,
cfg.gradient_accumulation_steps,
),
)
logging.info("save path:", cfg.exp_path)
if cfg.save_log:
if not os.path.exists(cfg.exp_path):
                    os.makedirs(cfg.exp_path)
# to gpt later
cfg.model_path = os.path.join(cfg.exp_path, "model.pkl")
cfg.result_path = os.path.join(cfg.exp_path, "result.csv")
cfg.vocab_path_eval = os.path.join(cfg.exp_path, "vocab")
cfg.eval_load_path = cfg.exp_path
cfg._init_logging_handler(args.mode)
if cfg.cuda:
if len(cfg.cuda_device) == 1:
cfg.multi_gpu = False
# torch.cuda.set_device(cfg.cuda_device[0])
device = torch.device("cuda:{}".format(cfg.cuda_device[0]))
else:
pass # multi-gpu
else:
device = torch.device("cpu")
# logging.info('Device: {}'.format(torch.cuda.current_device()))
# fix random seed
torch.manual_seed(cfg.seed)
torch.cuda.manual_seed(cfg.seed)
random.seed(cfg.seed)
np.random.seed(cfg.seed)
# initialize model
m = Model(device)
if args.mode == "train": # train
if cfg.save_log: # save cfg details.
pass
m.train()
else: # test
logging.info(
"Generate setting: \n\t use true_prev_bspn={} \n\t use true_prev_aspn={} \n\t use true_db_pointer={} \
\n\t use true_prev_resp={} \n\t use true_curr_bspn={} \n\t use true_curr_aspn={} \
\n\t use_all_previous_context={}".format(
cfg.use_true_prev_bspn,
cfg.use_true_prev_aspn,
cfg.use_true_db_pointer,
cfg.use_true_prev_resp,
cfg.use_true_curr_bspn,
cfg.use_true_curr_aspn,
cfg.use_all_previous_context,
)
)
logging.info("Running eval on test")
m.validate(cfg.eval_set)
logging.info("Evaluation finished")
if __name__ == "__main__":
main()