import logging
import os
import time
from typing import Any, List, Optional

import torch
from accelerate import Accelerator
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
from torch.utils.data import DataLoader, Dataset

from .utils import seed_all_random_engines

logger = logging.getLogger(__name__)

class TrainingLoopBase(ReplaceableBase):
    """
    Members:
        evaluator: An EvaluatorBase instance, used to evaluate training results.
    """

    evaluator: Optional[EvaluatorBase]
    evaluator_class_type: Optional[str] = "ImplicitronEvaluator"

    def run(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        **kwargs,
    ) -> None:
        raise NotImplementedError()

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        raise NotImplementedError()
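

# TrainingLoopBase is replaceable: a project can supply its own loop by
# subclassing it and registering the subclass, mirroring how
# ImplicitronTrainingLoop below is registered. A minimal, hypothetical sketch
# (the class name and body are illustrative, not part of this module):
#
#     @registry.register
#     class MyTrainingLoop(TrainingLoopBase):
#         max_epochs: int = 10
#
#         def run(self, train_loader, val_loader, test_loader, train_dataset,
#                 model, optimizer, scheduler, **kwargs) -> None:
#             ...  # custom epoch loop over train_loader
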

@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase):
    """
    Members:
        eval_only: If True, only run evaluation using the test dataloader.
        max_epochs: Train for this many epochs. Note that if the model was
            loaded from a checkpoint, we will restart training at the appropriate
            epoch and run for (max_epochs - checkpoint_epoch) epochs.
        store_checkpoints: If True, store model and optimizer state checkpoints.
        store_checkpoints_purge: If > 0, remove any checkpoints older than this
            many epochs.
        test_interval: Evaluate on a test dataloader each `test_interval` epochs.
        test_when_finished: If True, evaluate on a test dataloader when training
            completes.
        validation_interval: Validate each `validation_interval` epochs.
        clip_grad: Optionally clip the gradient norms; if set to a value <= 0.0,
            no clipping is performed.
        metric_print_interval: The batch interval at which the stats should be
            logged.
        visualize_interval: The batch interval at which the visualizations
            should be plotted.
        visdom_env: The name of the Visdom environment to use for plotting.
        visdom_port: The Visdom port.
        visdom_server: Address of the Visdom server.
    """

    # Parameters of the outer training loop.
    eval_only: bool = False
    max_epochs: int = 1000
    store_checkpoints: bool = True
    store_checkpoints_purge: int = 1
    test_interval: int = -1
    test_when_finished: bool = False
    validation_interval: int = 1

    # Gradient clipping (disabled when <= 0).
    clip_grad: float = 0.0

    # Visualization and logging parameters.
    metric_print_interval: int = 5
    visualize_interval: int = 1000
    visdom_env: str = ""
    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
    visdom_server: str = "http://127.0.0.1"

    def __post_init__(self):
        run_auto_creation(self)

    def run(
        self,
        *,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        accelerator: Optional[Accelerator],
        device: torch.device,
        exp_dir: str,
        stats: Stats,
        seed: int,
        **kwargs,
    ):
        """
        Entry point to run the training and validation loops
        based on the specified config file.
        """
        start_epoch = stats.epoch + 1
        assert scheduler.last_epoch == stats.epoch + 1
        assert scheduler.last_epoch == start_epoch
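
        # If requested, only run a single evaluation on the test set and return.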
        if self.eval_only:
            if test_loader is not None:
                self.evaluator.run(
                    dataloader=test_loader,
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    model=model,
                )
                return
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )
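
        # Loop through the training epochs.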
        for epoch in range(start_epoch, self.max_epochs):
            with stats:
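                # `with stats` starts a new epoch and plots the stats at epoch start.
                # Re-seed random generators so that runs resumed from a checkpoint
                # stay reproducible.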
                seed_all_random_engines(seed + epoch)

                cur_lr = float(scheduler.get_last_lr()[-1])
                logger.debug(f"scheduler lr = {cur_lr:1.2e}")
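
                # Train loop.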
                self._training_or_validation_epoch(
                    accelerator=accelerator,
                    device=device,
                    epoch=epoch,
                    loader=train_loader,
                    model=model,
                    optimizer=optimizer,
                    stats=stats,
                    validation=False,
                )
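
                # Validation loop (optional).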
                if val_loader is not None and epoch % self.validation_interval == 0:
                    self._training_or_validation_epoch(
                        accelerator=accelerator,
                        device=device,
                        epoch=epoch,
                        loader=val_loader,
                        model=model,
                        optimizer=optimizer,
                        stats=stats,
                        validation=True,
                    )
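
                # Test-set evaluation (optional).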
                if (
                    test_loader is not None
                    and self.test_interval > 0
                    and epoch % self.test_interval == 0
                ):
                    self.evaluator.run(
                        device=device,
                        dataloader=test_loader,
                        model=model,
                    )

                assert stats.epoch == epoch, "inconsistent stats!"
                self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats)
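
                # Adjust the learning rate.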
                scheduler.step()
                new_lr = float(scheduler.get_last_lr()[-1])
                if new_lr != cur_lr:
                    logger.info(f"LR change! {cur_lr} -> {new_lr}")

        if self.test_when_finished:
            if test_loader is not None:
                self.evaluator.run(
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    dataloader=test_loader,
                    model=model,
                )
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        """
        Load Stats that correspond to the model's log_vars and resume_epoch.

        Args:
            log_vars: A list of variable names to log. Should be a subset of the
                `preds` returned by the forward function of the corresponding
                ImplicitronModelBase instance.
            exp_dir: Root experiment directory.
            resume: If False, do not load stats from the checkpoint specified
                by resume_epoch; instead, create a fresh stats object.
            resume_epoch: If > 0, resume from the checkpoint of that epoch;
                otherwise resume from the last checkpoint found in `exp_dir`.

        Returns:
            stats: The stats structure (optionally loaded from checkpoint).
        """
        visdom_env_charts = (
            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
        )
        stats = Stats(
            list(log_vars),
            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
            visdom_env=visdom_env_charts,
            visdom_server=self.visdom_server,
            visdom_port=self.visdom_port,
        )
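
        # Locate the checkpoint to resume from, if any.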
        model_path = None
        if resume:
            if resume_epoch > 0:
                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
                if not os.path.isfile(model_path):
                    raise FileNotFoundError(
                        f"Cannot find stats from epoch {resume_epoch}."
                    )
            else:
                model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            stats_path = model_io.get_stats_path(model_path)
            stats_load = model_io.load_stats(stats_path)
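
            # Resume from the loaded stats, or rebuild them if the stats file is corrupt.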
            if resume:
                if stats_load is None:
                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
                    logger.info(f"Estimated resume epoch = {last_epoch}")

                    # Reset the stats struct.
                    for _ in range(last_epoch + 1):
                        stats.new_epoch()
                    assert last_epoch == stats.epoch
                else:
                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
                    stats = stats_load

                # Update stats properties in case they were reset on load.
                stats.visdom_env = visdom_env_charts
                stats.visdom_server = self.visdom_server
                stats.visdom_port = self.visdom_port
                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
                stats.synchronize_logged_vars(log_vars)
            else:
                logger.info("Clearing stats")

        return stats

    def _training_or_validation_epoch(
        self,
        epoch: int,
        loader: DataLoader,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
        validation: bool,
        *,
        accelerator: Optional[Accelerator],
        bp_var: str = "objective",
        device: torch.device,
        **kwargs,
    ) -> None:
        """
        This is the main loop for training and evaluation including:
        model forward pass, loss computation, backward pass and visualization.

        Args:
            epoch: The index of the current epoch
            loader: The dataloader to use for the loop
            model: The model module optionally loaded from checkpoint
            optimizer: The optimizer module optionally loaded from checkpoint
            stats: The stats struct, also optionally loaded from checkpoint
            validation: If true, run the loop with the model in eval mode
                and skip the backward pass
            accelerator: An optional Accelerator instance.
            bp_var: The name of the key in the model output `preds` dict which
                should be used as the loss for the backward pass.
            device: The device on which to run the model.
        """
        if validation:
            model.eval()
            trainmode = "val"
        else:
            model.train()
            trainmode = "train"

        t_start = time.time()

        visdom_env_imgs = stats.visdom_env + "_images_" + trainmode
        viz = vis_utils.get_visdom_connection(
            server=stats.visdom_server,
            port=stats.visdom_port,
        )
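
        # Loop over the batches of the epoch.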
        n_batches = len(loader)
        for it, net_input in enumerate(loader):
            last_iter = it == n_batches - 1
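
            # Move the input batch to the target device.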
            net_input = net_input.to(device)

            # Run the forward pass; gradients are only needed during training.
            if not validation:
                optimizer.zero_grad()
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.TRAINING}
                )
            else:
                with torch.no_grad():
                    preds = model(
                        **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                    )
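
            # Merge the predictions into the input dict; keys must not clash.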
            assert all(k not in preds for k in net_input.keys())
            preds.update(net_input)

            # Update the stats logger.
            stats.update(preds, time_start=t_start, stat_set=trainmode)
            assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

            # Print a textual status update.
            if it % self.metric_print_interval == 0 or last_iter:
                std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
                logger.info(std_out)
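
            # Visualize intermediate results (main process only).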
            if (
                (accelerator is None or accelerator.is_local_main_process)
                and self.visualize_interval > 0
                and it % self.visualize_interval == 0
            ):
                prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
                if hasattr(model, "visualize"):
                    model.visualize(
                        viz,
                        visdom_env_imgs,
                        preds,
                        prefix,
                    )
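
            # Optimization step (training only).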
            if not validation:
                loss = preds[bp_var]
                assert torch.isfinite(loss).all(), "Non-finite loss!"

                # Backprop, going through the Accelerator when one is provided.
                if accelerator is None:
                    loss.backward()
                else:
                    accelerator.backward(loss)
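
                # Optionally clip gradient norms to stabilize training.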
                if self.clip_grad > 0.0:
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), self.clip_grad
                    )
                    if total_norm > self.clip_grad:
                        logger.debug(
                            f"Clipping gradient: {total_norm}"
                            + f" with coef {self.clip_grad / float(total_norm)}."
                        )

                optimizer.step()

    def _checkpoint(
        self,
        accelerator: Optional[Accelerator],
        epoch: int,
        exp_dir: str,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
    ):
        """
        Save a model and its corresponding Stats object to a file, if
        `self.store_checkpoints` is True. In addition, if
        `self.store_checkpoints_purge` is positive, remove any checkpoints older
        than `self.store_checkpoints_purge` epochs old.
        """
        if self.store_checkpoints and (
            accelerator is None or accelerator.is_local_main_process
        ):
            if self.store_checkpoints_purge > 0:
                for prev_epoch in range(epoch - self.store_checkpoints_purge):
                    model_io.purge_epoch(exp_dir, prev_epoch)
            outfile = model_io.get_checkpoint(exp_dir, epoch)
            unwrapped_model = (
                model if accelerator is None else accelerator.unwrap_model(model)
            )
            model_io.safe_save_model(
                unwrapped_model, stats, outfile, optimizer=optimizer
            )