import logging

from contextlib import nullcontext

import torch
from torch.nn.utils import clip_grad_norm_


class Executor:

    def __init__(self):
        self.step = 0

    def train(
        self, model, optimizer, scheduler, data_loader, device, writer, args, scaler
    ):
        """Train one epoch"""
        model.train()
        clip = args.get("grad_clip", 50.0)
        log_interval = args.get("log_interval", 10)
        rank = args.get("rank", 0)
        epoch = args.get("epoch", 0)
        accum_grad = args.get("accum_grad", 1)
        is_distributed = args.get("is_distributed", True)
        use_amp = args.get("use_amp", False)
        logging.info(
            "using gradient accumulation; effective batch size is {} times"
            " larger than before".format(accum_grad)
        )
        if use_amp:
            assert scaler is not None
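
        # Under DDP, run the epoch inside model.join() so ranks that exhaust
        # their data early keep participating in collective ops instead of
        # hanging the other workers.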
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model_context = model.join
        else:
            model_context = nullcontext
        num_seen_utts = 0
        with model_context():
            for batch_idx, batch in enumerate(data_loader):
                key, feats, target, feats_lengths, target_lengths = batch
                feats = feats.to(device)
                target = target.to(device)
                feats_lengths = feats_lengths.to(device)
                target_lengths = target_lengths.to(device)
                num_utts = target_lengths.size(0)
                if num_utts == 0:
                    continue
                context = None
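                # Skip DDP gradient all-reduce on batches that do not end in
                # an optimizer step, so gradients are synchronized only once
                # every accum_grad batches.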
                if is_distributed and batch_idx % accum_grad != 0:
                    context = model.no_sync
                else:
                    context = nullcontext
                with context():
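                    # Forward pass runs under autocast only when AMP is
                    # enabled (i.e. a GradScaler was passed in).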
                    with torch.cuda.amp.autocast(scaler is not None):
                        loss_dict = model(feats, feats_lengths, target, target_lengths)
                        loss = loss_dict["loss"] / accum_grad
                    if use_amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()

                num_seen_utts += num_utts
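                # Update parameters only once every accum_grad batches.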
                if batch_idx % accum_grad == 0:
                    if rank == 0 and writer is not None:
                        writer.add_scalar("train_loss", loss, self.step)

                    if use_amp:
                        scaler.unscale_(optimizer)
                        grad_norm = clip_grad_norm_(model.parameters(), clip)
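                        # No explicit finiteness check here: scaler.step()
                        # skips optimizer.step() by itself when the unscaled
                        # gradients contain inf/nan, and scaler.update() must
                        # follow unscale_() within the same iteration.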
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        grad_norm = clip_grad_norm_(model.parameters(), clip)
                        if torch.isfinite(grad_norm):
                            optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                    self.step += 1
                if batch_idx % log_interval == 0:
                    lr = optimizer.param_groups[0]["lr"]
                    log_str = "TRAIN Batch {}/{} loss {:.6f} ".format(
                        epoch, batch_idx, loss.item() * accum_grad
                    )
                    for name, value in loss_dict.items():
                        if name != "loss" and value is not None:
                            log_str += "{} {:.6f} ".format(name, value.item())
                    log_str += "lr {:.8f} rank {}".format(lr, rank)
                    logging.debug(log_str)
    def cv(self, model, data_loader, device, args):
        """Cross validation over the given data loader"""
        model.eval()
        rank = args.get("rank", 0)
        epoch = args.get("epoch", 0)
        log_interval = args.get("log_interval", 10)
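        # Start at 1 so the running-average log below never divides by zero.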
        num_seen_utts = 1
        total_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch in enumerate(data_loader):
                key, feats, target, feats_lengths, target_lengths = batch
                feats = feats.to(device)
                target = target.to(device)
                feats_lengths = feats_lengths.to(device)
                target_lengths = target_lengths.to(device)
                num_utts = target_lengths.size(0)
                if num_utts == 0:
                    continue
                loss_dict = model(feats, feats_lengths, target, target_lengths)
                loss = loss_dict["loss"]
                if torch.isfinite(loss):
                    num_seen_utts += num_utts
                    total_loss += loss.item() * num_utts
                if batch_idx % log_interval == 0:
                    log_str = "CV Batch {}/{} loss {:.6f} ".format(
                        epoch, batch_idx, loss.item()
                    )
                    for name, value in loss_dict.items():
                        if name != "loss" and value is not None:
                            log_str += "{} {:.6f} ".format(name, value.item())
                    log_str += "history loss {:.6f}".format(total_loss / num_seen_utts)
                    log_str += " rank {}".format(rank)
                    logging.debug(log_str)
        return total_loss, num_seen_utts