import torch
import json
import torch.nn as nn
from tqdm import tqdm
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Optional
import os
def nested_map(struct, map_fn):
"""This is for trasfering into cuda device"""
if isinstance(struct, tuple):
return tuple(nested_map(x, map_fn) for x in struct)
if isinstance(struct, list):
return [nested_map(x, map_fn) for x in struct]
if isinstance(struct, dict):
return {k: nested_map(v, map_fn) for k, v in struct.items()}
return map_fn(struct)
class DiffproLearner:
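    """Training/validation loop for a model exposing get_loss_dict(batch, step):
    mixed precision, gradient clipping, TensorBoard summaries, and symlinked checkpoints."""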
def __init__(
self, output_dir, model, train_dl, val_dl, optimizer, params
):
# model output
self.output_dir = output_dir
self.log_dir = f"{output_dir}/logs"
self.checkpoint_dir = f"{output_dir}/chkpts"
# model (architecture and loss)
self.model = model
# data loader
self.train_dl = train_dl
self.val_dl = val_dl
# optimizer
self.optimizer = optimizer
        # training hyperparameters (attribute access, e.g. params.fp16, params.max_grad_norm); dumped to params.json below
self.params = params
        # training progress counters
self.step = 0
self.epoch = 0
self.grad_norm = 0.
# other information
self.summary_writer = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
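        # automatic mixed precision: both autocast and the scaler are effectively no-ops when params.fp16 is False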
self.autocast = torch.cuda.amp.autocast(enabled=params.fp16)
self.scaler = torch.cuda.amp.GradScaler(enabled=params.fp16)
self.best_val_loss = torch.tensor([1e10], device=self.device)
# restore if directory exists
if os.path.exists(self.output_dir):
self.restore_from_checkpoint()
else:
os.makedirs(self.output_dir)
os.makedirs(self.log_dir)
os.makedirs(self.checkpoint_dir)
with open(f"{output_dir}/params.json", "w") as params_file:
json.dump(self.params, params_file)
print(json.dumps(self.params, sort_keys=True, indent=4))
    def _write_summary(self, losses: dict, scheduled_params: Optional[dict], stage):
        """Log losses (plus grad_norm and any scheduled params) to TensorBoard under the given stage ("train" or "val")."""
summary_losses = losses
summary_losses["grad_norm"] = self.grad_norm
if scheduled_params is not None:
for k, v in scheduled_params.items():
summary_losses[f"sched_{k}"] = v
writer = self.summary_writer or SummaryWriter(
self.log_dir, purge_step=self.step
)
        writer.add_scalars(stage, summary_losses, self.step)
writer.flush()
self.summary_writer = writer
def state_dict(self):
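        """Bundle step/epoch counters with model, optimizer, and AMP-scaler state for torch.save."""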
# state dictionary
model_state = self.model.state_dict()
return {
"step": self.step,
"epoch": self.epoch,
"model":
{
k: v.cpu() if isinstance(v, torch.Tensor) else v
for k, v in model_state.items()
},
"optimizer":
{
k: v.cpu() if isinstance(v, torch.Tensor) else v
for k, v in self.optimizer.state_dict().items()
},
"scaler": self.scaler.state_dict(),
}
def load_state_dict(self, state_dict):
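        """Restore counters, model, optimizer, and scaler from a dict produced by state_dict()."""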
self.step = state_dict["step"]
self.epoch = state_dict["epoch"]
self.model.load_state_dict(state_dict["model"])
self.optimizer.load_state_dict(state_dict["optimizer"])
self.scaler.load_state_dict(state_dict["scaler"])
def restore_from_checkpoint(self, fname="weights"):
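        """Load the latest state from {checkpoint_dir}/{fname}.pt if it exists; return whether a checkpoint was found."""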
try:
fpath = f"{self.checkpoint_dir}/{fname}.pt"
checkpoint = torch.load(fpath)
self.load_state_dict(checkpoint)
print(f"Restored from checkpoint {fpath} --> {fname}-{self.epoch}.pt!")
return True
except FileNotFoundError:
print("No checkpoint found. Starting from scratch...")
return False
def _link_checkpoint(self, save_name, link_fpath):
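        # re-point the symlink at the newest epoch-stamped checkpoint file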
if os.path.islink(link_fpath):
os.unlink(link_fpath)
os.symlink(save_name, link_fpath)
def save_to_checkpoint(self, fname="weights", is_best=False):
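        """Save an epoch-stamped checkpoint and point {fname}.pt (and, if is_best, {fname}_best.pt) at it."""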
save_name = f"{fname}-{self.epoch}.pt"
save_fpath = f"{self.checkpoint_dir}/{save_name}"
link_best_fpath = f"{self.checkpoint_dir}/{fname}_best.pt"
link_fpath = f"{self.checkpoint_dir}/{fname}.pt"
torch.save(self.state_dict(), save_fpath)
self._link_checkpoint(save_name, link_fpath)
if is_best:
self._link_checkpoint(save_name, link_best_fpath)
def train(self, max_epoch=None):
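        """Run the training loop until max_epoch (or indefinitely if None), logging every 50 steps and validating every 5000 steps."""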
self.model.train()
while True:
self.epoch = self.step // len(self.train_dl)
if max_epoch is not None and self.epoch >= max_epoch:
return
for batch in tqdm(self.train_dl, desc=f"Epoch {self.epoch}"):
batch = nested_map(
batch, lambda x: x.to(self.device)
if isinstance(x, torch.Tensor) else x
)
losses, scheduled_params = self.train_step(batch)
# check NaN
                for loss_value in losses.values():
                    if isinstance(loss_value, torch.Tensor) and torch.isnan(loss_value).any():
                        raise RuntimeError(
                            f"Detected NaN loss at step {self.step}, epoch {self.epoch}"
                        )
if self.step % 50 == 0:
self._write_summary(losses, scheduled_params, "train")
if self.step % 5000 == 0 and self.step != 0 \
and self.epoch != 0:
self.valid()
self.step += 1
# valid
self.valid()
def valid(self):
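        """Average losses over the validation set, log them, and save a checkpoint (tracking the best validation loss)."""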
# self.model.eval()
        losses = None
        for batch in self.val_dl:
            batch = nested_map(
                batch, lambda x: x.to(self.device) if isinstance(x, torch.Tensor) else x
            )
            current_losses, _ = self.val_step(batch)
            if losses is None:
                # copy so the first batch is not added onto itself below
                losses = dict(current_losses)
            else:
                for k, v in current_losses.items():
                    losses[k] += v
        assert losses is not None
        # average the accumulated losses over the validation set
        for k in losses:
            losses[k] /= len(self.val_dl)
self._write_summary(losses, None, "val")
if self.best_val_loss >= losses["loss"]:
self.best_val_loss = losses["loss"]
self.save_to_checkpoint(is_best=True)
else:
self.save_to_checkpoint(is_best=False)
    def train_step(self, batch):
        # set grads to None instead of optimizer.zero_grad(): skips the explicit
        # zero-fill and lets the next backward pass allocate fresh gradients
        for param in self.model.parameters():
            param.grad = None
        # forward pass and loss computation under autocast
with self.autocast:
scheduled_params = None
loss_dict = self.model.get_loss_dict(batch, self.step)
loss = loss_dict["loss"]
self.scaler.scale(loss).backward()
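        # unscale gradients before clipping so the norm is measured at true scale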
self.scaler.unscale_(self.optimizer)
        self.grad_norm = nn.utils.clip_grad_norm_(
            self.model.parameters(), self.params.max_grad_norm or 1e9
        )
self.scaler.step(self.optimizer)
self.scaler.update()
return loss_dict, scheduled_params
def val_step(self, batch):
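        # forward pass only: same loss computation as train_step, but with no gradients or optimizer update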
with torch.no_grad():
with self.autocast:
scheduled_params = None
loss_dict = self.model.get_loss_dict(batch, self.step)
return loss_dict, scheduled_params
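# Example usage (sketch only; the model, dataloaders, optimizer, and params object are
# placeholders -- params must be JSON-serializable and expose .fp16 and .max_grad_norm):
#
#   learner = DiffproLearner("results/exp", model, train_dl, val_dl, optimizer, params)
#   learner.train(max_epoch=100)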