File size: 14,949 Bytes
74e8f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 |
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training loop with flexible/schedulable settings."""
# pylint: disable=consider-using-from-import
import functools
import importlib
import multiprocessing.pool
import os
from absl import app
from absl import flags
from absl import logging
import big_vision.evaluators.common as eval_common
import big_vision.input_pipeline as input_pipeline
import big_vision.optax as bv_optax
import big_vision.trainers.proj.flexi.common as flexi
import big_vision.utils as u
from clu import parameter_overview
import flax
import jax
import jax.numpy as jnp
from ml_collections import config_flags
import numpy as np
import optax
import tensorflow as tf
from tensorflow.io import gfile
# pylint: disable=logging-fstring-interpolation

# Command-line interface of this trainer:
#   --config   the (locked) ml_collections training configuration file,
#   --workdir  where checkpoints and metrics are written,
#   --cleanup  opt-in removal of the workdir once training finished cleanly.
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string(
    "workdir",
    default=None,
    help="Work unit directory.")
flags.DEFINE_boolean(
    "cleanup",
    default=False,
    help="Delete workdir (only) after successful completion.")

# Adds jax flags to the program.
jax.config.parse_flags_with_absl()
def main(argv):
  """Runs flexi training: data, model and optimizer setup, train loop, evals.

  Configuration is read from the `--config` flag and outputs go to
  `--workdir`. For every name in `config.flexi`, a value is sampled each step
  via `flexi.choice` and forwarded to the model as a keyword argument. These
  values are static arguments of the pmapped update step, so JAX compiles the
  step once per distinct combination of flexi values.

  Args:
    argv: Unused positional command-line arguments (absl convention).

  Raises:
    ValueError: If the global batch size is not divisible by the device count.
    RuntimeError: If the training loss becomes NaN or infinite.
  """
  del argv
  tf.config.experimental.set_visible_devices([], "GPU")

  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info(
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  save_ckpt_path = None
  if workdir:  # Always create if requested, even if we may not write into it.
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.npz")

  # The pool is used to perform misc operations such as logging in async way.
  pool = multiprocessing.pool.ThreadPool()

  # Here we register preprocessing ops from modules listed on `pp_modules`.
  for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]):
    importlib.import_module(f"big_vision.pp.{m}")

  # This seed makes the Jax part of things (like model init) deterministic.
  # However, full training still won't be deterministic, for example due to the
  # tf.data pipeline not being deterministic even if we would set TF seed.
  # See (internal link) for a fun read on what it takes.
  rng = jax.random.PRNGKey(config.get("seed", 0))

  # These functions do more stuff internally, for OSS release we mock them by
  # trivial alternatives in order to minimize disruptions in the code.
  # NOTE: `xid`/`wid` are also used below to seed the per-step flexi sampling,
  # so they must be the names used there (not any internal-only objects).
  xid, wid = -1, -1
  fillin = lambda s: s

  def info(s, *a):
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)

  def write_note(note):
    if jax.process_index() == 0:
      info("%s", note)

  write_note("Initializing...")

  batch_size = config.input.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  # First thing after above sanity checks, so we can log "start" ticks.
  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)

  write_note("Initializing train dataset...")
  train_ds, ntrain_img = input_pipeline.training(config.input)

  # Start prefetching already.
  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch)

  total_steps = u.steps("total", config, ntrain_img, batch_size)

  def get_steps(name, default=ValueError, cfg=config):
    return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default)

  u.chrono.inform(total_steps=total_steps, global_bs=batch_size,
                  steps_per_epoch=ntrain_img / batch_size,
                  measure=mw.measure, write_note=write_note)

  info("Running for %d steps, that means %f epochs",
       total_steps, total_steps * batch_size / ntrain_img)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model = model_mod.Model(
      num_classes=config.num_classes, **config.get("model", {}))

  # We want all parameters to be created in host RAM, not on any device, they'll
  # be sent there later as needed, otherwise we already encountered two
  # situations where we allocate them twice.
  @functools.partial(jax.jit, backend="cpu")
  def init(rng):
    shape = tuple(train_ds.element_spec["image"].shape[1:])
    bs = batch_size // jax.device_count()
    dummy_input = jnp.zeros((bs,) + shape, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Set bias in the head to a low value, such that loss is small initially.
    if "init_head_bias" in config:
      params["head"]["bias"] = jnp.full_like(params["head"]["bias"],
                                             config["init_head_bias"])

    return params

  rng, rng_init = jax.random.split(rng)
  with u.chrono.log_timing("z/secs/init"):
    params_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_util.tree_leaves(params_cpu))
    parameter_overview.log_parameter_overview(params_cpu, msg="init params")
    mw.measure("num_params", num_params)

  write_note(f"Initializing {config.optax_name} optimizer...")
  tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict(
      total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))

  # We jit this, such that the arrays are created on the CPU, not device[0].
  opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
  sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

  # Flexi values are passed to `update_fn` positionally after
  # (params, opt, rng, images, labels), in sorted-name order, which is why the
  # static argnums start at 5.
  flexi_argnames = sorted(config.flexi)

  @functools.partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1),
                     static_broadcasted_argnums=tuple(
                         range(5, 5 + len(flexi_argnames))))
  def update_fn(params, opt, rng, images, labels, *args):
    """Update step; `args` holds the flexi values ordered as `flexi_argnames`."""
    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Get device-specific loss rng.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": params}, images,
          train=True, rngs={"dropout": rng_model_local},
          **dict(zip(flexi_argnames, args)))
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    l, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    l, grads = jax.lax.pmean((l, grads), axis_name="batch")
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)

    gs = jax.tree_util.tree_leaves(
        bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs]))
    ps = jax.tree_util.tree_leaves(params)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
    # Loop variable is deliberately not `u`, which names the utils module here.
    us = jax.tree_util.tree_leaves(updates)
    measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(x, x) for x in us]))

    return params, opt, rng, l, measurements

  # We do not jit/pmap this function, because it is passed to evaluator that
  # does it later. We output as many intermediate tensors as possible for
  # maximal flexibility. Later `jit` will prune out things that are not needed.
  def predict_fn(params, image, **flexi_kw):
    logits, out = model.apply({"params": params}, image, **flexi_kw)
    return logits, out

  # Decide how to initialize training. The order is important.
  # 1. Always resumes from the existing checkpoint, e.g. resumes a finetune job.
  # 2. Resume from a previous checkpoint, e.g. start a cooldown training job.
  # 3. Initialize model from something, e.g., start a fine-tuning job.
  # 4. Train from scratch.
  resume_ckpt_path = None
  if save_ckpt_path and gfile.exists(save_ckpt_path):
    resume_ckpt_path = save_ckpt_path
  elif config.get("resume"):
    resume_ckpt_path = fillin(config.resume)
  if resume_ckpt_path:
    write_note("Resume training from checkpoint...")
    checkpoint = {
        "params": params_cpu,
        "opt": opt_cpu,
        "chrono": u.chrono.save(),
    }
    checkpoint_tree = jax.tree_util.tree_structure(checkpoint)
    loaded = u.load_checkpoint_np(resume_ckpt_path, checkpoint_tree)
    # bfloat16 type gets lost when data is saved to disk, so we recover it.
    checkpoint = jax.tree_util.tree_map(u.recover_dtype, loaded)
    params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"]
    u.chrono.load(checkpoint["chrono"])
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    params_cpu = model_mod.load(
        params_cpu, config.model_init, config.get("model"),
        **config.get("model_load", {}))
    if jax.process_index() == 0:
      parameter_overview.log_parameter_overview(
          params_cpu, msg="restored params")

  write_note("Kicking off misc stuff...")
  first_step = bv_optax.get_count(opt_cpu)
  u.chrono.inform(first_step=first_step)
  prof = None  # Keeps track of start/stop of profiler state.

  write_note(f"Replicating...\n{u.chrono.note}")
  params_repl = flax.jax_utils.replicate(params_cpu)
  opt_repl = flax.jax_utils.replicate(opt_cpu)

  @functools.cache
  def evaluators():
    return eval_common.from_config(
        config, flexi.mkpredictfns(predict_fn, config.flexi, "predict_{x}"),
        lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"),
        lambda key, cfg: get_steps(key, default=None, cfg=cfg),
    )

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax.jax_utils.replicate(rng_loop)
  ckpt_writer = None

  write_note(f"First step compilations...\n{u.chrono.note}")

  # Note that training can be pre-empted during the final evaluation (i.e.
  # just after the final checkpoint has been written to disk), in which case we
  # want to run the evals.
  if first_step in (total_steps, 0):
    mw.step_start(first_step)
    for (name, evaluator, _, prefix) in evaluators():
      if config.evals[name].get("skip_first") and first_step != total_steps:
        continue
      write_note(f"{name} evaluation...\n{u.chrono.note}")
      with u.chrono.log_timing(f"z/secs/eval/{name}"):
        for key, value in evaluator.run(params_repl):
          mw.measure(f"{prefix}{key}", value)

  # Using a python integer for step here, because opt.state.step is allocated
  # on TPU during replication.
  for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
    mw.step_start(step)

    # The flexi RNG is seeded only by (xid, wid, step); presumably this is so
    # that every host samples the same static flexi values — TODO confirm in
    # `flexi.mkrng`.
    np_rng = flexi.mkrng(xid, wid, step)
    flexi_args = [
        flexi.choice(config.flexi[n].v, config.flexi[n].p, np_rng)
        for n in flexi_argnames
    ]

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1):
        params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn(
            params_repl, opt_repl, rngs_loop, batch["image"], batch["labels"],
            *flexi_args)

    # On the first host, let's always profile a handful of early steps.
    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

    # Report training progress
    if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
        or (u.chrono.warmup and jax.process_index() == 0)):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1))
      l = mw.measure("training_loss", loss_value[0])
      for name, value in measurements.items():
        mw.measure(name, value[0])
      u.chrono.tick(step)
      if not np.isfinite(l):
        raise RuntimeError(f"The loss became nan or inf somewhere within steps "
                           f"[{step - get_steps('log_training')}, {step}]")

    # Checkpoint saving
    if (save_ckpt_path and
        (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or
         u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))):
      u.chrono.pause(wait_for=(params_repl, opt_repl))
      u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1))
      # We need to transfer the weights over now or else we risk keeping them
      # alive while they'll be updated in a future step, creating hard to debug
      # memory errors (see (internal link)). Also, takes device 0's params only.
      params_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), params_repl)
      opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl)

      # Check whether we want to keep a copy of the current checkpoint.
      copy_step = None
      if u.itstime(step, get_steps("keep_ckpt", None), total_steps):
        copy_step = step

      ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": u.chrono.save()}
      ckpt_writer = pool.apply_async(
          u.save_checkpoint, (ckpt, save_ckpt_path, copy_step))
      u.chrono.resume()

    for (name, evaluator, log_steps, prefix) in evaluators():
      if u.itstime(step, log_steps, total_steps, first=False, last=True):
        u.chrono.pause(wait_for=params_repl)
        u.chrono.tick(step)  # Record things like epoch number, core hours etc.
        write_note(f"{name} evaluation...\n{u.chrono.note}")
        with u.chrono.log_timing(f"z/secs/eval/{name}"):
          for key, value in evaluator.run(params_repl):
            mw.measure(f"{prefix}{key}", value)
        u.chrono.resume()
    mw.step_end()

  # Always give a chance to stop the profiler, no matter how things ended.
  # TODO: can we also do this when dying of an exception like OOM?
  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)

  # Last note needs to happen before the pool's closed =)
  write_note(f"Done!\n{u.chrono.note}")

  pool.close()
  pool.join()
  mw.close()

  # Make sure all hosts stay up until the end of main.
  u.sync()

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)
# Script entry point: absl parses the command-line flags, then calls `main`.
if __name__ == "__main__":
  app.run(main)
|