import json
import random
import sys
import time
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import (
    AdamW,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    get_linear_schedule_with_warmup,
)
import wandb
from crazyneuraluser.user_model_code.argument import get_args
# from interact import interact
from crazyneuraluser.user_model_code.dataset import SGD_Dataset
from crazyneuraluser.user_model_code.utils_generation import decode_e2e
from crazyneuraluser.user_model_code.utils_sgd import get_special_tokens
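# Entry point for training, fine-tuning and decoding the GPT-2-based user model:
# the script fine-tunes GPT2LMHeadModel with teacher forcing on SGD-formatted
# dialogues, tracks dev loss on Weights & Biases for early stopping, and decodes
# the test split end-to-end via decode_e2e.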
def print_loss(epoch, data_type, LOSS, t0):
    print(
        "Epoch: {} | {} loss: {:.3f} | time: {:.1f}".format(
            epoch, data_type, LOSS, time.time() - t0
        )
    )
def print_score(epoch, data_type, res, t0):
    print(
        "Epoch: {} | {}: joint_acc: {:.2f}%, slot_acc: {:.2f}% | time: {:.1f}".format(
            epoch,
            data_type,
            res["avg_joint_acc"],
            res["avg_slot_acc"],
            time.time() - t0,
        )
    )
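# run_one_epoch runs a single pass over `dataloader` with the (model, optimizer,
# scheduler, tokenizer) tuple in `trainer`: in "teacher_force" mode it returns the
# mean per-batch loss, in "generation" mode it writes decoded outputs into
# `collector` (via decode_e2e) and returns None.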
def run_one_epoch(data_type, dataloader, trainer, epoch, run_type, collector=None):
    t0 = time.time()
    assert data_type in ["dev", "test"]
    assert run_type in ["teacher_force", "generation"]
    model, optimizer, scheduler, tokenizer = trainer
    LOSS = 0
    # result = {"slot_acc": [], "joint_acc": []}
    # mention_match = 0
    # coref_lines = []
    iterator = enumerate(
        tqdm(
            dataloader,
            desc="Epoch {} {}".format(epoch, run_type),
            disable=args.disable_display,
        )
    )
    for step, batch in iterator:
        if run_type == "teacher_force":
            # read the loss from the model output explicitly instead of relying
            # on the field order of ModelOutput.values()
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"],
                labels=batch["label_ids"],
            )
            LOSS += outputs.loss.item()
        else:
            decode_e2e(args, batch, model, tokenizer, collector=collector)
    # print log
    if run_type == "teacher_force":
        LOSS /= step + 1
        print_loss(epoch, data_type, LOSS, t0)
        return LOSS
    else:  # generation
        # TODO: add evaluation code here
        return None
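# set_dataloader builds a DataLoader over the SGD-formatted split `data_type`,
# using random sampling and train_batch_size for the training split and
# sequential sampling and eval_batch_size otherwise.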
def set_dataloader(args, tokenizer, data_type, run_type, data_size=-1):
    dataset = SGD_Dataset(
        args, tokenizer, data_type, run_type == "generation", data_size
    )
    # sys.exit(1)
    if data_type == "train":
        sampler = RandomSampler(
            dataset
        )  # if args.local_rank == -1 else DistributedSampler(train_dataset)
    else:
        sampler = SequentialSampler(dataset)
    dataloader = DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.train_batch_size
        if data_type == "train"
        else args.eval_batch_size,
        collate_fn=dataset.collate_fn,
    )
    return dataloader
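# train fine-tunes the model with teacher forcing and logs to Weights & Biases.
# Every `eval_interval` training examples it evaluates on the dev split, saves a
# checkpoint, and stops early once `no_improve_max` consecutive evaluations show
# no improvement in (negative) dev loss.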
def train(args, tokenizer, model):
    wandb.init(
        # Set the project where this run will be logged
        project="E2E User Simulator (Alistair)",
        entity="byrne-lab",
        # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
        name=args.wandb_train_run_name,
        # Track hyperparameters and run metadata
        config={
            "data_dir": args.data_dir,
            "model_name": args.model_name,
            "learning_rate": args.learning_rate,
            "gradient_accumulation_steps": args.gradient_accumulation_steps,
            "train_batch_size": args.train_batch_size,
            "eval_batch_size": args.eval_batch_size,
        },
    )
    # load data
    train_dataloader = set_dataloader(
        args, tokenizer, "train", "teacher_force", data_size=args.train_size
    )
    dev_dataloader = set_dataloader(
        args, tokenizer, "dev", "teacher_force", data_size=args.eval_size
    )
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
    if args.use_scheduler:
        t_total = (
            len(train_dataloader) // args.gradient_accumulation_steps * args.max_epoch
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )
    else:
        scheduler = None
    trainer = (model, optimizer, scheduler, tokenizer)
print("Do evaluation before training!")
model.eval()
with torch.no_grad():
_ = run_one_epoch("dev", dev_dataloader, trainer, -1, "teacher_force")
print("Start training!\n{}".format("***" * 30))
eval_step = args.eval_interval // args.train_batch_size
best_score = -100
global_step = 0
no_improve_count = 0
for epoch in range(args.max_epoch):
# initialize for each epoch training
t0 = time.time()
model.train()
model.zero_grad()
LOSS = 0
iterator = enumerate(
tqdm(
train_dataloader,
desc="Epoch {}".format(epoch),
disable=args.disable_display,
)
)
for local_step, batch in iterator:
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                token_type_ids=batch["token_type_ids"],
                labels=batch["label_ids"],
            )
            loss = outputs.loss
            # accumulate a plain float so computation graphs are not retained
            LOSS += loss.item()
            global_step += 1
            wandb.log({"loss": loss})
            # update model
            if loss != 0:
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
            if global_step % args.gradient_accumulation_steps == 0:
                # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                if args.use_scheduler:
                    scheduler.step()
                optimizer.zero_grad()
            # evaluate model
            if global_step % eval_step == 0:
                model.eval()
                with torch.no_grad():
                    loss = run_one_epoch(
                        "dev", dev_dataloader, trainer, epoch, "teacher_force"
                    )
                score = -loss  # negative dev loss as the early-stopping criterion
                wandb.log({"dev_loss": loss})
                model.train()
                save_checkpoint(
                    args, tokenizer, model, global_step * args.train_batch_size
                )
                if score > best_score:
                    best_score = score
                    print("Best score: {:.2f}".format(best_score))
                    no_improve_count = 0
                else:
                    no_improve_count += 1
                # early stop
                if no_improve_count == args.no_improve_max:
                    print("Early stop!")
                    return
        LOSS /= local_step + 1
        print_loss(epoch, "train", LOSS, t0)
        print("***" * 30)
        wandb.log({"epoch": epoch, "epoch_loss": LOSS})
    # Mark the run as finished on wandb
    wandb.finish()
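# test decodes the test split with the trained model; the generated outputs are
# gathered into `collector` by decode_e2e and dumped as JSON to args.decode_file.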
def test(args, tokenizer, model):
    # load data
    test_gen_dataloader = set_dataloader(args, tokenizer, "test", "generation")
    trainer = (model, None, None, tokenizer)
    model.eval()
    collector = {"decode-dev": {}, "decode-test": {}}
    with torch.no_grad():
        # # evaluate on dev
        # _ = run_one_epoch('dev', dev_dataloader, trainer, 'Eval', 'teacher_force')
        # # generate on dev
        # res_dev = run_one_epoch('dev', dev_gen_dataloader, trainer, 'Dev', 'generation',
        #     collector=collector['decode-dev'])
        # collector['result-dev'] = res_dev
        # print_qr_result(res_dev['qr'], 'dev')
        # generate on test
        res_test = run_one_epoch(
            "test",
            test_gen_dataloader,
            trainer,
            "Test",
            "generation",
            collector=collector["decode-test"],
        )
        collector["result-test"] = res_test
    out_file = args.decode_file
    with open(out_file, "w") as f:
        json.dump(collector, f, indent=4, sort_keys=True)
    print("Decode file is saved at {}".format(out_file))
    print("Done decoding!")
def save_checkpoint(args, tokenizer, model, step):
    save_path = args.checkpoint + "_step" + str(step)
    print("Save model in {}!".format(save_path))
    tokenizer.save_pretrained(save_path)
    model.save_pretrained(save_path)
def load_checkpoint(args):
    save_path = args.checkpoint  # + '_step' + str(args.step)
    print("Load model, tokenizer from {}".format(save_path))
    tokenizer = GPT2Tokenizer.from_pretrained(save_path)
    model = GPT2LMHeadModel.from_pretrained(save_path)
    model.to(args.device)
    return tokenizer, model
def load_pretrained_model(args):
    save_path = args.pre_checkpoint
    print("Load model, tokenizer from {}".format(save_path))
    tokenizer = GPT2Tokenizer.from_pretrained(save_path)
    model = GPT2LMHeadModel.from_pretrained(save_path)
    model.to(args.device)
    return tokenizer, model
def set_model(args, SPECIAL_TOKENS):
    """Initialize config, tokenizer and model."""
    # add special tokens into tokenizer
    config = GPT2Config.from_pretrained(args.model_name_or_path)
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.add_special_tokens(SPECIAL_TOKENS)
    model = GPT2LMHeadModel.from_pretrained(
        args.model_name_or_path, config=config
    )  # GPT2LMHeadModel
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    print("Done setting model")
    return config, tokenizer, model
def set_seed(args):
    """Fix random seeds for reproducibility."""
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
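# Example invocations (a sketch only; the exact flag names are defined in
# crazyneuraluser.user_model_code.argument.get_args and are assumed here to
# mirror the attribute names used above, with a hypothetical script name):
#   python <this_script>.py --mode training --model_name_or_path gpt2 ...
#   python <this_script>.py --mode testing --checkpoint <saved_checkpoint_dir> ...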
if __name__ == "__main__":
    # Load arguments
    args = get_args()
    # Set seed, device
    set_seed(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device
    # Load special tokens
    SPECIAL_TOKENS = get_special_tokens()
    if args.mode == "training":
        config, tokenizer, model = set_model(args, SPECIAL_TOKENS)
        train(args, tokenizer, model)
    elif args.mode == "finetune":
        tokenizer, model = load_pretrained_model(args)
        train(args, tokenizer, model)
    elif args.mode == "testing":
        tokenizer, model = load_checkpoint(args)
        test(args, tokenizer, model)
    # elif args.mode == 'interact':
    #     tokenizer, model = load_checkpoint(args)
    #     interact(args, tokenizer, model)
    else:
        sys.exit("Unknown mode: {}".format(args.mode))