# ------------------------------------------------------------------------
# Copyright (c) 2024-present, BAAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------
"""Custom trainer focused on data parallelism specialization."""
import collections
import os
import shutil
import torch
from diffnext.engine import engine_utils
from diffnext.engine.model_ema import ModelEMA
from diffnext.pipelines.builder import build_pipeline
from diffnext.pipelines.builder import get_pipeline_path
from diffnext.utils import accelerate_utils
from diffnext.utils import profiler
from diffnext.utils.omegaconf_utils import config_to_class
from diffnext.utils.omegaconf_utils import config_to_object
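

# Typical driver code (a sketch, not part of this module; the `config`,
# `accelerator`, and `logger` objects are constructed by the launch script):
#
#   trainer = Trainer(config, accelerator, logger)
#   trainer.train_loop()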
class Trainer(object):
    """Schedule the iterative model training."""

    def __init__(self, config, accelerator, logger):
"""Create a trainer instance."""
self.config, self.accelerator, self.logger = config, accelerator, logger
self.dtype = accelerate_utils.precision_to_dtype(config.training.mixed_precision)
self.train_dataloader = config_to_object(config.train_dataloader)
self.pipe_path = get_pipeline_path(**config.pipeline.paths)
self.pipe = build_pipeline(self.pipe_path, config_to_class(config.pipeline), self.dtype)
self.pipe = self.pipe.to(device=engine_utils.get_device(config.training.gpu_id))
self.ema = ModelEMA(self.pipe.model, **config.ema.params) if "ema" in config else None
self.model = self.pipe.configure_model(config, accelerator, logger)
        # Default to all trainable parameters; optionally let the engine build
        # dedicated groups (e.g. with a per-group ``lr_scale``, applied in ``run_step``).
        param_groups = [p for p in self.model.parameters() if p.requires_grad]
        if config.optimizer.get("param_groups", True):
            param_groups = engine_utils.get_param_groups(self.model)
self.optimizer = config_to_object(config.optimizer, params=param_groups)
self.scheduler = config_to_object(config.lr_scheduler)
self.model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
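        # Optional sequence parallelism: the prepared model exposes its process
        # group, which the underlying model must know how to consume.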
if config.training.get("sequence_parallel_size", 1) > 1:
if not hasattr(self.pipe.model, "configure_sequence_parallel"):
raise RuntimeError("Model does not support sequence parallelism.")
self.pipe.model.configure_sequence_parallel(self.model.seq_parallel_group)
self.metrics = collections.OrderedDict()
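        # When resuming, restore the EMA weights from the parallel
        # "ema_checkpoints" tree written by ``save()``.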
if self.ema and config.experiment.resume_iter > 0:
ckpt = config.experiment.resume_from_checkpoint
ema_ckpt = ckpt.replace("checkpoints", "ema_checkpoints")
ema_weights = os.path.join(ema_ckpt, config.model.name, "diffusion_pytorch_model.bin")
engine_utils.load_weights(self.ema.model, ema_weights)

    @property
    def global_step(self) -> int:
        """Return the global iteration step.

        Returns:
            int: The global step.
        """
        # The LR scheduler's internal step counter serves as the single source
        # of truth for the global step.
        return self.scheduler._step_count

    def save(self):
        """Save the checkpoint of the current iteration step."""
f = "checkpoint-{}/{}".format(self.global_step, self.config.model.name)
f = os.path.join(self.config.experiment.output_dir, "checkpoints", f)
if self.accelerator.is_main_process and not os.path.exists(f):
self.model.save_pretrained(f, safe_serialization=False)
self.logger.info("Wrote snapshot to: {:s}".format(f))
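            # Mirror the snapshot layout under "ema_checkpoints" so EMA weights
            # can be resumed with the same directory convention (see ``__init__``).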
if self.ema is not None:
config_json = os.path.join(f, "config.json")
f = f.replace("checkpoints", "ema_checkpoints")
                os.makedirs(f, exist_ok=True)
                shutil.copy(config_json, os.path.join(f, "config.json"))
f = os.path.join(f, "diffusion_pytorch_model.bin")
torch.save(self.ema.model.state_dict(), f)

    def add_metrics(self, stats):
        """Add or update the metrics.

        Args:
            stats (Dict)
                The current iteration stats.
        """
for k, v in stats["metrics"].items():
if k not in self.metrics:
self.metrics[k] = profiler.SmoothedValue()
self.metrics[k].update(v)

    def log_metrics(self, stats):
        """Send metrics to available trackers.

        Args:
            stats (Dict)
                The current iteration stats.
        """
iter_template = "Iteration %d, lr = %.8f, time = %.2fs"
metric_template = " " * 4 + "Train net output({}): {}"
        self.logger.info(iter_template % (stats["step"], stats["lr"], stats["time"]))
        for k, v in self.metrics.items():
            self.logger.info(metric_template.format(k, v))
tracker_logs = dict((k, v.median) for k, v in self.metrics.items())
tracker_logs.update({"lr": stats["lr"], "time": stats["time"]})
self.accelerator.log(tracker_logs, step=stats["step"])
self.metrics.clear()

    def run_model(self, inputs, metrics, accum_steps=1):
        """Run multiple model steps.

        Args:
            inputs (Dict)
                The model inputs.
            metrics (Dict)
                The current iteration metrics.
            accum_steps (int, optional, defaults to 1)
                The gradient accumulation steps.
        """
for _ in range(accum_steps):
inputs = inputs if inputs else self.train_dataloader.next()[0]
outputs, losses = self.model(inputs), []
for k, v in outputs.items():
if "loss" not in k and "metric" not in k:
continue
if isinstance(v, torch.Tensor) and v.requires_grad:
losses.append(v)
if k.startswith("metric/"): # Custom metrics.
metrics[k[len("metric/") :]] += float(v.mean()) / accum_steps
elif f"metric/{k}" not in outputs: # Legacy metrics.
metrics[k] += float(self.accelerator.gather(v).mean()) / accum_steps
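            # Collapse all collected loss tensors into a single scalar for backward.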
losses = sum(losses[1:], losses[0])
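            # Enter the accumulation context by hand before backward; note it is
            # never exited, presumably so Accelerate can skip gradient
            # synchronization on non-boundary steps.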
self.accelerator.accumulate().__enter__()
self.accelerator.backward(losses)

    def run_step(self, inputs, accum_steps=1) -> dict:
        """Run a single iteration step.

        Args:
            inputs (Dict)
                The model inputs.
            accum_steps (int, optional, defaults to 1)
                The gradient accumulation steps.

        Returns:
            Dict: The current iteration stats.
        """
stats = {"step": self.global_step}
metrics = collections.defaultdict(float)
timer = profiler.Timer().tic()
stats["lr"] = self.scheduler.get_lr()
for group in self.optimizer.param_groups:
group["lr"] = stats["lr"] * group.get("lr_scale", 1.0)
self.run_model(inputs, metrics, accum_steps)
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
self.scheduler.step()
stats["time"] = timer.toc()
stats["metrics"] = collections.OrderedDict(sorted(metrics.items()))
return stats

    def train_loop(self):
        """Run the training loop."""
timer = profiler.Timer()
max_steps = self.config.training.max_train_steps
accum_steps = self.config.training.gradient_accumulation_steps
log_every = self.config.experiment.log_every
save_every = self.config.experiment.save_every
data_every, inputs = self.config.training.get("data_every", -1), {}
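        # When resuming, seed the scheduler's counter so ``global_step``
        # continues from the checkpointed iteration.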
self.scheduler._step_count = self.config.experiment.get("resume_iter", 0)
while self.global_step < max_steps:
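            # Refresh the cached batch every ``data_every`` steps; when disabled
            # (-1), ``run_model`` pulls a fresh batch at every accumulation step.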
if data_every >= 1 and self.global_step % data_every == 0:
inputs = self.train_dataloader.next()[0]
with timer.tic_and_toc():
stats = self.run_step(inputs, accum_steps)
self.add_metrics(stats)
if stats["step"] % log_every == 0:
self.log_metrics(stats)
if self.global_step % (10 * log_every) == 0:
self.logger.info(profiler.get_progress(timer, self.global_step, max_steps))
if self.ema and self.global_step % self.ema.update_every == 0:
self.ema.update(self.model)
if self.global_step % save_every == 0:
self.save()
stats["step"] = self.global_step
self.log_metrics({**stats, **{"step": self.global_step}})
self.accelerator.wait_for_everyone()
self.accelerator.end_training()