import time
from datetime import datetime
from pathlib import Path
from typing import Any

from omegaconf import OmegaConf

import ignite.distributed as idist
import numpy as np
import torch
from torch.utils.data import DataLoader
from ignite.contrib.engines import common
from ignite.contrib.handlers import TensorboardLogger
from ignite.engine import Engine, Events, EventsList
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
from ignite.utils import manual_seed, setup_logger
from torch.cuda.amp import autocast, GradScaler

from scenedino.common.logging import event_list_from_config, global_step_fn, log_basic_info
from scenedino.common.io.configs import save_hydra_config
from scenedino.common.io.model import load_checkpoint
from scenedino.evaluation.wrapper import make_eval_fn
from scenedino.losses.base_loss import BaseLoss
from scenedino.training.handlers import (
    MetricLoggingHandler,
    VisualizationHandler,
    add_time_handlers,
)
from scenedino.common.array_operations import to
from scenedino.common.metrics import DictMeanMetric, MeanMetric, SegmentationMetric, ConcatenateMetric
from scenedino.visualization.vis_2d import tb_visualize

import optuna


def base_training(local_rank, config, get_dataflow, initialize, sweep_trial=None):
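    """Run the full ignite training loop for one configuration.

    Args:
        local_rank: local process rank, used to query the CUDA device name.
        config: OmegaConf-style configuration of the run.
        get_dataflow: callable ``config -> (train_loader, val_loaders)`` where
            ``val_loaders`` maps validation names to dataloaders.
        initialize: callable ``config -> (model, optimizer, criterion, lr_scheduler)``.
        sweep_trial: optional Optuna trial; when given, validation results are
            reported to it for pruning.

    Returns:
        The best validation metric recorded on the trainer state.
    """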
    # ============================================ LOGGING AND OUTPUT SETUP ============================================
    # TODO: figure out rank
    rank = (
        idist.get_rank()
    )  ## rank of the current process within the process group: each process can handle a unique subset of the data based on its rank
    manual_seed(config["seed"] + rank)
    device = idist.device()

    model_name = config["name"]
    logger = setup_logger(
        name=model_name, format="%(levelname)s: %(message)s"
    )  ## default

    output_path = config["output"]["path"]
    if rank == 0:
        unique_id = config["output"].get(
            "unique_id", datetime.now().strftime("%Y%m%d-%H%M%S")
        )
        folder_name = unique_id
        # folder_name = f"{model_name}_backend-{idist.backend()}-{idist.get_world_size()}_{unique_id}"

        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output"]["path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output']['path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

        log_basic_info(logger, config)

    tb_logger = TensorboardLogger(log_dir=output_path)

    # ================================================== DATASET SETUP =================================================
    # TODO: think about moving the dataset setup to the create validators and create trainer functions
    train_loader, val_loaders = get_dataflow(config)
    if isinstance(train_loader, tuple):
        train_loader = train_loader[0]

    if hasattr(train_loader, "dataset"):
        val_loader_lengths = "\n".join(
            [
                f"{name}: {len(val_loader.dataset)}"
                for name, val_loader in val_loaders.items()
                if hasattr(val_loader, "dataset")
            ]
        )
        logger.info(
            f"Dataset lengths:\nTrain: {len(train_loader.dataset)}\n{val_loader_lengths}"
        )

    config["dataset"]["steps_per_epoch"] = len(train_loader)

    # ============================================= MODEL AND OPTIMIZATION =============================================
    model, optimizer, criterion, lr_scheduler = initialize(config)

    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters())}")
    logger.info(f"Trainable model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    # Create trainer for current task
    trainer = create_trainer(
        model,
        optimizer,
        criterion,
        lr_scheduler,
        train_loader.sampler if hasattr(train_loader, "sampler") else None,
        config,
        logger,
        metrics={},
    )

    if rank == 0:
        tb_logger.attach(
            trainer,
            MetricLoggingHandler("train", optimizer),
            Events.ITERATION_COMPLETED(every=config.get("log_every_iters", 1)),
        )
    # ========================================== EVALUATION AND VISUALIZATION ==========================================
    validators: dict[str, tuple[Engine, EventsList]] = create_validators(
        config,
        model,
        val_loaders,
        criterion,
        tb_logger,
        trainer,
    )

    # NOTE: not super elegant as val_loaders has to have the same name but should work
    def run_validation(name: str, validator: Engine):
        def _run(engine: Engine):
            epoch = trainer.state.epoch
            state = validator.run(val_loaders[name])
            log_metrics(logger, epoch, state.times["COMPLETED"], name, state.metrics)
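
            # Report the current best metric to Optuna so that unpromising sweep
            # trials can be pruned after the "validation" split has been evaluated.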
            if sweep_trial is not None and name == "validation":
                sweep_trial.report(trainer.state.best_metric, trainer.state.iteration)
                if sweep_trial.should_prune():
                    raise optuna.TrialPruned()

        return _run

    for name, validator in validators.items():
        trainer.add_event_handler(validator[1], run_validation(name, validator[0]))

    # ================================================ SAVE FINAL CONFIG ===============================================
    if rank == 0:
        # Plot config to tensorboard
        config_yaml = OmegaConf.to_yaml(config)
        config_yaml = "".join("\t" + line for line in config_yaml.splitlines(True))
        tb_logger.writer.add_text("config", text_string=config_yaml, global_step=0)

        save_hydra_config(output_path / "training_config.yaml", config, force=False)
    # ================================================= TRAINING LOOP ==================================================
    # In order to check training resuming we can stop training on a given iteration
    if config.get("stop_iteration", None):

        @trainer.on(Events.ITERATION_STARTED(once=config["stop_iteration"]))
        def _():
            logger.info(f"Stop training on {trainer.state.iteration} iteration")
            trainer.terminate()
    try:  ## train_loader == models.bts.trainer_overfit.DataloaderDummy object
        trainer.run(train_loader,
                    max_epochs=config["training"]["num_epochs"],
                    epoch_length=config["training"].get("epoch_length", None))
    except Exception as e:
        logger.exception("")
        raise e

    if rank == 0:
        tb_logger.close()

    return trainer.state.best_metric


def log_metrics(logger, epoch, elapsed, tag, metrics):
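    """Log the metrics of one evaluation run as a readable multi-line block."""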
    metrics_output = "\n".join([f"\t{k}: {v}" for k, v in metrics.items()])
    logger.info(
        f"\nEpoch {epoch} - Evaluation time (seconds): {elapsed:.2f} - {tag} metrics:\n {metrics_output}"
    )


def create_trainer(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    criterions: list[Any],
    lr_scheduler,
    train_sampler,
    config,
    logger,
    metrics={},
):
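    """Build the ignite training engine.

    The engine runs ``train_step`` (AMP forward pass, summed losses, optional
    gradient accumulation), gets the common ignite handlers attached
    (checkpointing, LR scheduling, TerminateOnNan), and can resume from or be
    warm-started with a checkpoint as configured under ``config["training"]``.
    """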
    device = idist.device()
    model = model.to(device)

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage on train_step output
    #    - Two progress bars on epochs and optionally on iterations
    with_amp = config["with_amp"]
    gradient_accum_factor = config.get("gradient_accum_factor", 1)
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, data: dict):
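        # One optimization step. The returned dict is consumed downstream by the
        # logging handlers and metrics: "output" holds the model outputs,
        # "loss_dict" the individual loss terms, "timings_dict" per-stage timings.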
if "t__get_item__" in data: | |
timing = {"t__get_item__": torch.mean(data["t__get_item__"]).item()} | |
else: | |
timing = {} | |
_start_time = time.time() | |
data = to(data, device) | |
timing["t_to_gpu"] = time.time() - _start_time | |
model.train() | |
model.validation_tag = None | |
_start_time = time.time() | |
with autocast(enabled=with_amp): | |
data = model(data) | |
timing["t_forward"] = time.time() - _start_time | |
loss_metrics = {} | |
if optimizer is not None: | |
_start_time = time.time() | |
overall_loss = torch.tensor(0.0, device=device) | |
for criterion in criterions: | |
losses = criterion(data) | |
names = criterion.get_loss_metric_names() | |
overall_loss += losses[names[0]] | |
loss_metrics.update({name: loss for name, loss in losses.items()}) | |
timing["t_loss"] = time.time() - _start_time | |
            ## GradScaler keeps gradients at a consistent scale under AMP; note it is not an ignite built-in
            ## (cf. https://wandb.ai/wandb_fc/tips/reports/How-To-Use-GradScaler-in-PyTorch--VmlldzoyMTY5MDA5)
            _start_time = time.time()
            # optimizer.zero_grad()
            # scaler.scale(overall_loss).backward()
            # scaler.step(optimizer)
            # scaler.update()

            # Gradient accumulation
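            # The loss is divided by gradient_accum_factor so the gradients summed over
            # the accumulation window average out to one effective step; the optimizer
            # is only stepped (and gradients cleared) every gradient_accum_factor
            # iterations.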
            overall_loss = overall_loss / gradient_accum_factor
            scaler.scale(overall_loss).backward()
            if engine.state.iteration % gradient_accum_factor == 0:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
            timing["t_backward"] = time.time() - _start_time

        return {
            "output": data,
            "loss_dict": loss_metrics,
            "timings_dict": timing,
            "metrics_dict": {},
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    for name, metric in metrics.items():
        metric.attach(trainer, name)

    # TODO: maybe save only the network not the whole wrapper
    # TODO: Make adaptable
    to_save = {
        "trainer": trainer,
        "model": model,
        # "optimizer": optimizer,
        # "lr_scheduler": lr_scheduler,
    }

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["training"]["checkpoint_every"],
        save_handler=DiskSaver(config["output"]["path"], require_empty=False),
        lr_scheduler=lr_scheduler,
        output_names=None,
        with_pbars=False,
        clear_cuda_cache=False,
        log_every_iters=config.get("log_every_iters", 100),
        n_saved=1,
    )
    # NOTE: don't move to initialization, as to_save is also needed here
if config["training"].get("resume_from", None): | |
ckpt_path = Path(config["training"]["resume_from"]) | |
logger.info(f"Resuming from checkpoint: {str(ckpt_path)}") | |
load_checkpoint(ckpt_path, to_save, strict=False) | |
if config["training"].get("from_pretrained", None): | |
ckpt_path = Path(config["training"]["from_pretrained"]) | |
logger.info(f"Pretrained from checkpoint: {str(ckpt_path)}") | |
to_save = {"model": to_save["model"]} | |
load_checkpoint(ckpt_path, to_save, strict=False) | |
if idist.get_rank() == 0: | |
common.ProgressBar(desc=f"Training", persist=False).attach(trainer) | |
return trainer | |
def create_validators( | |
config, | |
model: torch.nn.Module, | |
dataloaders: dict[str, DataLoader], | |
criterions: list[BaseLoss], | |
tb_logger: TensorboardLogger, | |
trainer: Engine, | |
) -> dict[str, tuple[Engine, EventsList]]: | |
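    """Build one validation engine per entry in ``config["validation"]``.

    Each validator evaluates the configured metrics (and optionally the loss
    terms), logs and visualizes results to TensorBoard, and can track a
    "save_best" checkpoint. Returns a mapping from validation name to
    ``(engine, events)``, where ``events`` tells the trainer when to run it.
    """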
    # TODO: change model object to evaluator object that has a different ray sampler
    with_amp = config["with_amp"]
    device = idist.device()

    def _create_validator(
        tag: str,
        validation_config,
    ) -> tuple[Engine, EventsList]:
        # TODO: make eval functions configurable from config
        metrics = {}
        for metric_config in validation_config["metrics"]:
            agg_type = metric_config.get("agg_type", None)
            if agg_type == "unsup_seg":
                metrics[metric_config["type"]] = SegmentationMetric(
                    metric_config["type"], make_eval_fn(model, metric_config), assign_pseudo=True
                )
            elif agg_type == "sup_seg":
                metrics[metric_config["type"]] = SegmentationMetric(
                    metric_config["type"], make_eval_fn(model, metric_config), assign_pseudo=False
                )
            elif agg_type == "concat":
                metrics[metric_config["type"]] = ConcatenateMetric(
                    metric_config["type"], make_eval_fn(model, metric_config)
                )
            else:
                metrics[metric_config["type"]] = DictMeanMetric(
                    metric_config["type"], make_eval_fn(model, metric_config)
                )

        loss_during_validation = validation_config.get("log_loss", True)
        if loss_during_validation:
            metrics_loss = {}
            for criterion in criterions:
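                # The outer lambda binds k by value; a plain `lambda x: x["loss_dict"][k]`
                # would late-bind k and make every metric read the last loss name.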
                metrics_loss.update(
                    {
                        k: MeanMetric((lambda y: lambda x: x["loss_dict"][y])(k))
                        for k in criterion.get_loss_metric_names()
                    }
                )
            eval_metrics = {**metrics, **metrics_loss}
        else:
            eval_metrics = metrics

        def validation_step(engine: Engine, data):
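            # Mirrors train_step but without an optimizer update; the validation_tag
            # presumably lets the model and eval functions know which split is running.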
            model.eval()
            model.validation_tag = tag

            if "t__get_item__" in data:
                timing = {"t__get_item__": torch.mean(data["t__get_item__"]).item()}
            else:
                timing = {}

            data = to(data, device)

            with autocast(enabled=with_amp):
                data = model(data)

            overall_loss = torch.tensor(0.0, device=device)
            loss_metrics = {}
            if loss_during_validation:
                for criterion in criterions:
                    losses = criterion(data)
                    names = criterion.get_loss_metric_names()
                    overall_loss += losses[names[0]]
                    loss_metrics.update({name: loss for name, loss in losses.items()})
            else:
                loss_metrics = {}

            return {
                "output": data,
                "loss_dict": loss_metrics,
                "timings_dict": timing,
                "metrics_dict": {},
            }

        validator = Engine(validation_step)
        add_time_handlers(validator)

        # ADD METRICS
        for name, metric in eval_metrics.items():
            metric.attach(validator, name)

        # ADD LOGGING HANDLER
        # TODO: split up handlers
        tb_logger.attach(
            validator,
            MetricLoggingHandler(
                tag,
                log_loss=False,
                global_step_transform=global_step_fn(
                    trainer, validation_config["global_step"]
                ),
            ),
            Events.EPOCH_COMPLETED,
        )

        # ADD VISUALIZATION HANDLER
        if validation_config.get("visualize", None):
            visualize = tb_visualize(
                (model.renderer.net if hasattr(model, "renderer") else model.module.renderer.net),
                dataloaders[tag].dataset.dataset,
                validation_config["visualize"],
            )

            def vis_wrapper(*args, **kwargs):
                with autocast(enabled=with_amp):
                    return visualize(*args, **kwargs)

            tb_logger.attach(
                validator,
                VisualizationHandler(
                    tag=tag,
                    visualizer=vis_wrapper,
                    global_step_transform=global_step_fn(
                        trainer, validation_config["global_step"]
                    ),
                ),
                Events.ITERATION_COMPLETED(every=1),
            )

        if "save_best" in validation_config:
            save_best_config = validation_config["save_best"]
            metric_name = save_best_config["metric"]
            sign = save_best_config.get("sign", 1.0)
            update_model = save_best_config.get("update_model", False)
            dry_run = save_best_config.get("dry_run", False)

            best_model_handler = Checkpoint(
                {"model": model},
                # NOTE: fixes a problem with log_dir or logdir
                DiskSaver(Path(config["output"]["path"]), require_empty=False),
                # DiskSaver(tb_logger.writer.log_dir, require_empty=False),
                filename_prefix=f"{metric_name}_best",
                n_saved=1,
                global_step_transform=global_step_from_engine(trainer),
                score_name=metric_name,
                score_function=Checkpoint.get_default_score_fn(
                    metric_name, score_sign=sign
                ),
            )

            def event_handler(engine):
                if update_model:
                    model.update_model_eval(engine.state.metrics)
                if not dry_run:
                    best_model_handler(engine)
                    trainer.state.best_metric = best_model_handler._saved[0].priority

            validator.add_event_handler(Events.COMPLETED, event_handler)

        if idist.get_rank() == 0 and (not validation_config.get("with_clearml", False)):
            common.ProgressBar(desc=f"Evaluation ({tag})", persist=False).attach(
                validator
            )

        return validator, event_list_from_config(validation_config["events"])
    return {
        name: _create_validator(name, val_config)
        for name, val_config in config["validation"].items()
    }