| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Utils very specific to this project, not generic.""" |
| |
|
| | import collections |
| | import contextlib |
| | import dataclasses |
| | import functools |
| | import io |
| | import json |
| | import multiprocessing |
| | import multiprocessing.pool |
| | import os |
| | import re |
| | import sys |
| | import time |
| | from typing import Mapping |
| |
|
| | from absl import flags |
| | from absl import logging |
| | from big_vision.pp import registry as pp_registry |
| | import einops |
| | import flax |
| | import flax.jax_utils as flax_utils |
| | import jax |
| | from jax.experimental import mesh_utils |
| | from jax.experimental.array_serialization import serialization as array_serial |
| | import jax.numpy as jnp |
| | import ml_collections as mlc |
| | import numpy as np |
| |
|
| | import tensorflow.io.gfile as gfile |
| |
|
| |
|
| | Registry = pp_registry.Registry |
| |
|
| |
|
| | |
| |
|
| |
|
def pad_shard_unpad(wrapped, static_argnums=(0,), static_argnames=()):
  """Wraps a function with code that pads, shards, then un-shards, un-pads.

  Args:
    wrapped: the function to be wrapped. Signature is `params, *args, *kwargs`.
    static_argnums: indices of arguments to `wrapped` that should _not_ be
      padded and sharded, but instead be forwarded as-is. The default is (0,)
      because by far the most common use-case is to pass `params` first.
    static_argnames: names of kwargs to `wrapped` that should _not_ be padded
      and sharded, but instead be forwarded as-is.

  Returns:
    A new function that pads and shards its arguments before passing them to
    the wrapped function, and un-shards and un-pads the returned pytree.

  This is useful for calling a pmap'ed function with inputs that aren't
  divisible by the number of devices. A typical use is:
    @pad_shard_unpad
    @jax.pmap
    def forward(params, x): ...

  Notes:
    The padding is done in host-memory before being passed to the function, and
    the values returned by the function are transferred back to host memory.

    The returned function is augmented with a new keyword-only argument
    `min_device_batch` that, if specified, forces padding inputs to at least
    this size per device. This can be useful to avoid recompiles for the last
    batch and reduce memory fragmentation.
  """

  def pad_shard_unpad_wrapper(*args, min_device_batch=None, **kw):
    d = jax.local_device_count()

    # Collect the leading (batch) dimension of every leaf of one argument.
    def get_bs(x):
      batch_sizes = jax.tree.map(lambda y: y.shape[0], x)
      return jax.tree.flatten(batch_sizes)[0]

    # All non-static args/kwargs must agree on a single batch size `b`.
    bs_a = [get_bs(a) for i, a in enumerate(args) if i not in static_argnums]
    bs_kw = [get_bs(v) for k, v in kw.items() if k not in static_argnames]
    bs = set([n for b in (bs_a + bs_kw) for n in b])
    assert len(bs) == 1, f"Inconsistent batch-sizes: {bs}"
    b = bs.pop()

    def pad(x):
      _, *shape = x.shape
      # db = per-device batch size; rest = rows that don't divide evenly.
      db, rest = divmod(b, d)
      if rest:
        # Zero-pad in host memory up to the next multiple of `d`.
        x = np.concatenate([x, np.zeros((d - rest, *shape), x.dtype)], axis=0)
        db += 1
      if min_device_batch and db < min_device_batch:
        # Pad further to a fixed per-device batch, to avoid re-compiles for
        # the (smaller) last batch of an epoch.
        x = np.concatenate(
            [x, np.zeros((d * (min_device_batch - db), *shape), x.dtype)])
        db = min_device_batch
      # Shard: leading batch dim becomes (devices, per-device batch).
      return x.reshape(d, db, *shape)

    def maybe_pad(x, actually_pad=True):
      if not actually_pad: return x  # Allows same code path for static args.
      return jax.tree.map(pad, x)

    args = [maybe_pad(a, i not in static_argnums) for i, a in enumerate(args)]
    kw = {k: maybe_pad(v, k not in static_argnames) for k, v in kw.items()}
    out = wrapped(*args, **kw)

    def unpad(x):
      # Fetch to host, merge device+batch dims back, drop the padded rows.
      return einops.rearrange(jax.device_get(x), "d b ... -> (d b) ...")[:b]
    return jax.tree.map(unpad, out)

  return pad_shard_unpad_wrapper
| |
|
| |
|
def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
  """Returns a float32 one-hot encoding of `labels` over `num_classes`.

  Matching positions get `on_value` and all others `off_value` (which makes
  label smoothing easy). Output shape is `labels.shape + (num_classes,)`.
  """
  hits = labels[..., None] == jnp.arange(num_classes)[None]
  on = jnp.full(hits.shape, on_value)
  off = jnp.full(hits.shape, off_value)
  return jax.lax.select(hits, on, off).astype(jnp.float32)
| |
|
| |
|
def npload(fname):
  """Loads `fname` and returns an np.ndarray or dict thereof."""
  if os.path.exists(fname):
    # Local file: let numpy read it directly.
    arr_or_npz = np.load(fname, allow_pickle=False)
  else:
    # Remote path (eg gs://): pull the bytes via gfile, then parse them.
    with gfile.GFile(fname, "rb") as f:
      arr_or_npz = np.load(io.BytesIO(f.read()), allow_pickle=False)

  # A .npy file yields an array; a .npz archive yields a lazy mapping that we
  # materialize into a plain dict.
  if isinstance(arr_or_npz, np.ndarray):
    return arr_or_npz
  return dict(arr_or_npz)
| |
|
| |
|
def load_checkpoint_np(npz, tree=None):
  """Loads a jax pytree from a npz file.

  Args:
    npz: Either path to the checkpoint file (.npz), or a dict-like.
    tree: deprecated, use None.
      Bwd-compat for old format that only stored values: the pytree structure.

  Returns:
    A pytree that is the checkpoint.
  """
  if isinstance(npz, str):
    npz = npload(npz)  # Accept a path; load it into a dict-like first.
  flat_names, flat_values = zip(*npz.items())
  if tree:
    # Legacy checkpoints stored only values; the structure comes from `tree`.
    return tree.unflatten(flat_values)
  # Current format: the structure is encoded in slash-separated names.
  return recover_tree(flat_names, flat_values)
| |
|
| |
|
def load_params(ckpt, **kw):
  """Loads the parameters of a big_vision checkpoint, both old or new format.

  Args:
    ckpt: Path to the checkpoint (.npz, .ts) or dict-like.
    **kw: forwarded to the underlying load function (_np or _ts).

  Returns:
    A pytree that is the checkpoint, potentially sharded.

  Notes:
    The `ckpt` string can contain a colon-separated "submodel" indicator, as
    in `/path/to/file.npz:img`. It is used to load sub-parts of a model, for
    example the image encoder out of a two_tower (SigLIP) checkpoint, or for
    distillation. This way, ANY model that uses this function can load itself
    from a checkpoint that contains multiple sub-models.
  """
  key = None  # Optional sub-tree to extract, eg "img".

  if isinstance(ckpt, str):
    # Split off an optional ":submodel" suffix; the path must contain a "/"
    # so it can be told apart from the submodel key.
    match = re.match(r"^(.*?/.*?)(?::([\w/]+))?$", ckpt)
    if not match:
      raise ValueError(f"Weird ckpt path: {ckpt} ; Maybe prepend ./ ?")
    ckpt, key = match.groups()

  if ".npz" in ckpt:
    # npz checkpoint; may be bare params or nested in a few historic ways.
    checkpoint = load_checkpoint_np(ckpt, **kw)
    checkpoint = jax.tree.map(recover_dtype, checkpoint)
    if "params" in checkpoint:
      params = checkpoint["params"]  # Checkpoint with optimizer state.
    elif "opt" in checkpoint:
      params = checkpoint["opt"]["target"]  # Old flax-optim style checkpoint.
    else:
      params = checkpoint  # When only params were stored.
  else:
    # Tensorstore checkpoint: restrict loading to the requested params.
    assert "regex" not in kw, "For a custom regex, use tsload directly."
    kw["regex"] = f"params/{key}($|/.*)" if key else "params/.*"
    checkpoint = load_checkpoint_ts(ckpt, **kw)
    params = checkpoint["params"]

  if key is not None:
    params = tree_get(params, key)

  return params
| |
|
| |
|
def prefetch_scalar(it, nprefetch=1, devices=None):
  """Replicates each scalar from `it` across local devices and prefetches."""
  if devices:
    ndev = len(devices)
  else:
    ndev = jax.local_device_count()
  # Turn each scalar into a length-`ndev` array so every device gets a copy.
  replicated = (i * np.ones(ndev) for i in it)
  return flax_utils.prefetch_to_device(replicated, nprefetch, devices)
| |
|
| |
|
def sigmoid_xent(*, logits, labels, reduction=True):
  """Sigmoid (per-class binary) cross-entropy, summed over the last axis.

  Both log(p) and log(1-p) are computed via log_sigmoid for numerical
  stability. Returns the batch mean when `reduction` is True, otherwise the
  per-example values.
  """
  logp_pos = jax.nn.log_sigmoid(logits)
  logp_neg = jax.nn.log_sigmoid(-logits)
  nll = -jnp.sum(labels * logp_pos + (1. - labels) * logp_neg, axis=-1)
  if reduction:
    return jnp.mean(nll)
  return nll
| |
|
| |
|
def bidirectional_contrastive_loss(zimg, ztxt, t, mask=None, reduction=False):
  """Bidirectional contrastive loss (e.g. for contrastive trainer/evaluator)."""
  # Temperature-scaled similarity matrix; entry [i, j] pairs image i / text j.
  logits = jnp.dot(zimg, ztxt.T) * t

  if mask is not None:
    # Push logits involving any invalid (padded) example to -inf so the
    # softmax ignores them in both directions.
    invalid = jnp.logical_not(mask)
    pair_invalid = jnp.logical_or(invalid[:, None], invalid[None, :])
    logits = jnp.where(pair_invalid, -jnp.inf, logits)

  # NLL of the matching pair, in image->text and text->image directions.
  img2txt = -jnp.diag(jax.nn.log_softmax(logits, axis=1))
  txt2img = -jnp.diag(jax.nn.log_softmax(logits, axis=0))
  l = 0.5 * (img2txt + txt2img)

  if mask is not None:
    l = jnp.where(mask, l, 0)  # Zero out losses of invalid examples.

  # When reducing with a mask, average only over the valid examples.
  redux = jnp.mean if reduction else lambda x: x
  if reduction and mask is not None:
    redux = lambda x: jnp.sum(x * mask) / (jnp.sum(mask) + 1e-8)

  return redux(l), {
      "ncorrect": redux(jnp.argmax(logits, axis=1) == jnp.arange(len(logits))),
  }
| |
|
| |
|
def softmax_xent(*, logits, labels, reduction=True, kl=False, axis=-1):
  """Softmax cross-entropy with soft labels.

  With `kl=True`, the (logit-independent) label entropy term is added,
  turning the cross-entropy into KL(labels || softmax(logits)).
  """
  nll = -jnp.sum(labels * jax.nn.log_softmax(logits, axis=axis), axis=axis)
  if kl:
    # Clip inside the log to avoid log(0) for hard (one-hot) labels.
    nll = nll + jnp.sum(labels * jnp.log(jnp.clip(labels, 1e-8)), axis=axis)
  return jnp.mean(nll) if reduction else nll
| |
|
| |
|
def weighted_softmax_xent(*,
                          logits,
                          labels,
                          reduction=True,
                          weights=None,
                          label_smoothing=0.0,
                          normalize=True):
  """Compute weighted cross entropy.

  Args:
    logits: [batch, length, num_classes] float array.
    labels: categorical targets [batch, length] int array.
    reduction: reduce across batch dim.
    weights: None or array of shape [batch, length].
    label_smoothing: label smoothing constant, used to determine the on and off
      values.
    normalize: normalize each "sentence" loss by the number of tokens in it.

  Returns:
    The mean loss (scalar) when `reduction`, else the per-example losses.
  """
  if logits.ndim != labels.ndim + 1:
    raise ValueError("Incorrect shapes. Got shape %s logits and %s targets" %
                     (str(logits.shape), str(labels.shape)))

  # Smooth the one-hot targets: `confidence` on the true class, the leftover
  # probability mass spread uniformly over the remaining classes.
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  soft_targets = onehot(
      labels, vocab_size, on_value=confidence, off_value=low_confidence)

  token_loss = -jnp.sum(soft_targets * jax.nn.log_softmax(logits), axis=-1)

  # Per-sentence normalizer: token count, or total weight when masking.
  normalizing_factor = labels.shape[1]
  if weights is not None:
    token_loss = token_loss * weights
    normalizing_factor = jnp.clip(weights.sum(axis=1), 2e-38)

  sent_loss = token_loss.sum(axis=1)
  if normalize:
    sent_loss = sent_loss / normalizing_factor

  return sent_loss.mean() if reduction else sent_loss
| |
|
| |
|
def accumulate_gradient(loss_and_grad_fn, params, images, labels, accum_steps):
  """Accumulate gradient over multiple steps to save on memory."""
  # No accumulation requested: a single full-batch evaluation.
  if not accum_steps or accum_steps <= 1:
    return loss_and_grad_fn(params, images, labels)

  assert images.shape[0] % accum_steps == 0, (
      f"Bad accum_steps {accum_steps} for batch size {images.shape[0]}")
  step_size = images.shape[0] // accum_steps

  # The first shard initializes the running (loss, grads) accumulator.
  loss, grads = loss_and_grad_fn(params, images[:step_size], labels[:step_size])

  def acc_grad_and_loss(i, loss_and_grads):
    # dynamic_slice because `i` is a traced loop index.
    imgs = jax.lax.dynamic_slice(images, (i * step_size, 0, 0, 0),
                                 (step_size,) + images.shape[1:])
    lbls = jax.lax.dynamic_slice(labels, (i * step_size, 0),
                                 (step_size, labels.shape[1]))
    loss_i, grads_i = loss_and_grad_fn(params, imgs, lbls)
    prev_loss, prev_grads = loss_and_grads
    return (prev_loss + loss_i,
            jax.tree.map(lambda a, b: a + b, prev_grads, grads_i))

  loss, grads = jax.lax.fori_loop(1, accum_steps, acc_grad_and_loss,
                                  (loss, grads))
  # Average, so the result matches a single big-batch evaluation.
  return jax.tree.map(lambda x: x / accum_steps, (loss, grads))
| |
|
| |
|
def itstime(step, every_n_steps, total_steps, host=None, last=True, first=True,
            drop_close_to_last=0.25):
  """Returns True if it's time to execute an action.

  Args:
    step: the current step representing "now".
    every_n_steps: the action should run every this many steps.
    total_steps: the step number of the last step of training.
    host: host number. If provided, only run if we are this process.
    last: whether to run on the last step or not.
    first: whether to run on the first step or not.
    drop_close_to_last: if a step would run, but is this close (in terms of
      fraction of every_n_step) to the last one, skip.

  Returns:
    True if the action should be executed, False if not.
  """
  # A regular run this close to the final step is redundant with the
  # final-step run, so it gets skipped.
  close_to_last = False
  if drop_close_to_last and every_n_steps:
    close_to_last = abs(step - total_steps) < drop_close_to_last * every_n_steps

  on_host = host is None or jax.process_index() == host
  on_interval = (every_n_steps and step % every_n_steps == 0
                 and not close_to_last)
  on_last = every_n_steps and step == total_steps
  on_first = every_n_steps and step == 1
  return on_host and (on_interval or (last and on_last) or (first and on_first))
| |
|
| |
|
def checkpointing_timeout(writer, timeout):
  """Waits up to `timeout` seconds on the previous async checkpoint write.

  A write that takes longer raises a (chained) TimeoutError, since it usually
  means checkpointing has become a training bottleneck.
  """
  if writer is None:
    return
  try:
    # `writer` is an async result; `.get` blocks until the write finished.
    writer.get(timeout=timeout)
  except multiprocessing.TimeoutError as e:
    raise TimeoutError(
        "Checkpoint writing seems to be a bottleneck. Make sure you do "
        "not do something wrong, like writing checkpoints to a distant "
        "cell. In a case you are OK with checkpoint writing being a "
        "bottleneck, you can configure `ckpt_timeout` parameter") from e
| |
|
| |
|
def hms(s):
  """Formats a duration `s` in seconds like "42s", "1m30s", "5h7m", "2d4h10m".

  Seconds are intentionally dropped once the duration reaches an hour, since
  they are noise at that scale.
  """
  if s < 60:
    return f"{s:.0f}s"
  m, s = divmod(s, 60)
  if m < 60:
    return f"{m:.0f}m{s:.0f}s"
  h, m = divmod(m, 60)
  # Fixed off-by-one: was `h < 25`, which (inconsistently with the `< 60`
  # checks above) rendered e.g. 24.5 hours as "24h30m" instead of "1d0h30m".
  if h < 24:
    return f"{h:.0f}h{m:.0f}m"
  d, h = divmod(h, 24)
  return f"{d:.0f}d{h:.0f}h{m:.0f}m"
| |
|
| |
|
class Chrono:
  """Measures time and reports progress, hyper-specific to our train loops.

  Some concepts:
  1. This differentiates between three "types" of time:
    - training time: the time spent on actual training (fprop/bprop/update)
    - program time: overall time the program runs, including all overheads
    - pause time: the chronometer can be paused (eg during evals).
  2. This handles a "warmup": the first step is skipped for training time
    purposes, as it includes significant compilation overheads, which distort
    estimates.
  3. `accum`ulates (i.e. integrates) timings, and save/load them across
    restarts.
  """

  def __init__(self):
    self._timing_history = collections.defaultdict(list)
    self._measure = None  # Callback `measure(name, value)`, set via `inform`.
    self._write_note = None  # Callback `write_note(str)`, set via `inform`.

    self.program_start_time = time.monotonic()
    self.train_start_time = None  # Set when warmup ends.
    self.train_start_step = None  # Step at which actual timing started.

    self.prev_time = None
    self.prev_step = None

    self.pause_start = None
    self.paused_time = 0

    # Filled in later via `inform`.
    self.total_steps = None
    self.global_bs = None
    self.steps_per_epoch = None

    self.warmup = 2  # How many ticks to skip before timing (compilation).
    self.load()  # Initializes the accum_* integrators to zero.
    self.note = "Chrono n/a"

  def inform(self, *, first_step=None, total_steps=None, global_bs=None,
             steps_per_epoch=None, measure=None, write_note=None):
    """Provide some extra info that's only known later in the program."""
    # The `x or self.x` pattern allows calling `inform` multiple times with
    # only partial information: unset (None) arguments keep previous values.
    self.prev_step = first_step if first_step is not None else self.prev_step
    self.total_steps = total_steps or self.total_steps
    self.steps_per_epoch = steps_per_epoch or self.steps_per_epoch
    self.global_bs = global_bs or self.global_bs
    self._measure = measure or self._measure
    self._write_note = write_note or self._write_note
    if self.total_steps and self.prev_step is not None:
      self.note = (f"Steps:{self.prev_step}/{self.total_steps} "
                   f"[{self.prev_step/self.total_steps:.1%}]")

  def tick(self, step, measure=None, write_note=None):
    """A chronometer tick."""
    if step == self.prev_step: return  # Can happen from evals for example.

    measure = measure or self._measure
    write_note = write_note or self._write_note

    now = time.monotonic()
    measure("uptime", now - self.program_start_time)
    self.flush_timings()

    # Progress and example counts are always reported, even during warmup.
    ds = step - self.prev_step  # Steps since the previous tick.
    self.prev_step = step
    self.accum_examples_seen += ds * self.global_bs
    measure("examples_seen", self.accum_examples_seen)
    measure("progress", step / self.total_steps)
    if self.steps_per_epoch:
      measure("epoch", step / self.steps_per_epoch)

    # The first tick(s) include compilation overheads; skip them so they do
    # not distort speed/ETA estimates. Timing starts at warmup == 1.
    if self.warmup > 1:
      self.warmup -= 1
      write_note(self.note)  # This can help debugging.
      return
    if self.warmup == 1:
      self.train_start_time = self.prev_time = now
      self.train_start_step = step
      self.accum_program_time += now - self.program_start_time
      self.paused_time = 0  # Drop pauses that happened before timing starts.
      self.warmup = 0
      write_note(self.note)  # This can help debugging.
      return

    # Training time between ticks, ignoring pause time (eg evals).
    dt = now - self.prev_time - self.paused_time
    ncores = jax.device_count()  # Global (not local) device count.
    measure("img/sec/core", self.global_bs * ds / dt / ncores)

    # Accumulate (integrate) times, good for plots.
    self.accum_train_time += dt
    self.accum_pause_time += self.paused_time
    self.accum_program_time += dt + self.paused_time

    # Convert to core-hours, useful for cost tracking.
    core_hours = self.accum_train_time * ncores / 60 / 60
    devtype = jax.devices()[0].device_kind
    measure(f"core_hours_{devtype}", core_hours)
    measure("core_hours", core_hours)

    # Progress note with "global" averages over the whole timed run.
    dt = now - self.train_start_time  # Time elapsed since end of warmup.
    steps_timed = step - self.train_start_step
    steps_todo = self.total_steps - step
    self.note = f"Steps:{step}/{self.total_steps} [{step/self.total_steps:.1%}]"
    self.note += f"\nWalltime:{hms(self.accum_program_time)}"
    self.note += f" ({hms(self.accum_pause_time)} eval)"
    self.note += f"\nETA:{hms(dt / steps_timed*steps_todo)}"
    self.note += f"\nTotal train time:{hms(dt / steps_timed*self.total_steps)}"
    write_note(self.note)

    log_memory(measure)

    self.prev_time = now
    self.paused_time = 0

  def pause(self, wait_for=()):
    """Stops the training-time clock, eg while running evals."""
    assert self.pause_start is None, "Don't pause twice."
    jax.block_until_ready(wait_for)
    self.pause_start = time.monotonic()

  def resume(self):
    """Restarts the clock after a `pause`, accounting the paused duration."""
    self.paused_time += time.monotonic() - self.pause_start
    self.pause_start = None

  def save(self):
    """Returns the accumulated timings, to be stored in a checkpoint."""
    return dict(
        accum_program_time=self.accum_program_time,
        accum_train_time=self.accum_train_time,
        accum_pause_time=self.accum_pause_time,
        accum_examples_seen=self.accum_examples_seen,
    )

  def load(self, ckpt=None):
    """Restores accumulated timings from `save()`; None/{} resets to zero."""
    # Fixed: the previous `ckpt={}` mutable default was only read, never
    # mutated, but `None` + fallback is the safe idiom.
    ckpt = ckpt or {}
    self.accum_program_time = float(ckpt.get("accum_program_time", 0.0))
    self.accum_train_time = float(ckpt.get("accum_train_time", 0.0))
    self.accum_pause_time = float(ckpt.get("accum_pause_time", 0.0))
    self.accum_examples_seen = int(ckpt.get("accum_examples_seen", 0))

  @contextlib.contextmanager
  def log_timing(self, name, *, noop=False):
    """Use this when you time sth once per step and want instant flushing."""
    t0 = time.monotonic()
    yield  # Always yield so the with-block runs even when noop.
    dt = time.monotonic() - t0
    if not noop:
      if self._measure:  # Keep logging timings even without a metric writer.
        self._measure(name, dt)
      logging.info("TIMING[%s]: %s", name, dt)
      logging.flush()

  @contextlib.contextmanager
  def log_timing_avg(self, name, *, noop=False):
    """Use this when you time sth multiple times per step (eg in a loop)."""
    t0 = time.monotonic()
    yield
    dt = time.monotonic() - t0
    if not noop:
      self._timing_history[name].append(dt)
      logging.info("TIMING[%s]: avg %s current %s",
                   name, np.mean(self._timing_history[name]), dt)
      logging.flush()

  def flush_timings(self):
    """Reports the mean of each accumulated timing and clears the history."""
    assert self._measure is not None
    for name, times in self._timing_history.items():
      self._measure(name, np.mean(times))
    self._timing_history.clear()
| |
|
| |
|
| | |
# Module-level singleton instance shared by the train loops.
chrono = Chrono()
| |
|
| |
|
def log_memory(measure):
  """Log a bunch of memory-related measurements.

  Args:
    measure: a callable `measure(name, value)` that records one metric.
  """
  # psutil is not a hard dependency; host metrics are skipped without it.
  try:
    import psutil  # pylint: disable=g-import-not-at-top
  except ImportError:
    psutil = None

  if psutil is not None:
    # Host (whole-machine) memory stats.
    vmem = psutil.virtual_memory()
    measure("y/hostmem/total", vmem.total)
    measure("y/hostmem/available", vmem.available)
    measure("y/hostmem/used", vmem.used)

  # Device memory stats, for at most the first two local accelerators.
  # `memory_stats()` may return None on some platforms, hence the `or {}`.
  for i, d in zip([0, 1], jax.local_devices()):
    for k, v in (d.memory_stats() or {}).items():
      measure(f"y/devmem/dev{i}/{k}", v)
| |
|
| |
|
| | def _traverse_with_names(tree, with_inner_nodes=False): |
| | """Traverses nested dicts/dataclasses and emits (leaf_name, leaf_val).""" |
| | if dataclasses.is_dataclass(tree): |
| | tree = flax.serialization.to_state_dict(tree) |
| | |
| | |
| | |
| | if tree is None: |
| | return |
| | elif isinstance(tree, Mapping): |
| | keys = sorted(tree.keys()) |
| | for key in keys: |
| | for path, v in _traverse_with_names(tree[key], with_inner_nodes): |
| | yield (key + "/" + path).rstrip("/"), v |
| | if with_inner_nodes: |
| | yield "", tree |
| | elif isinstance(tree, (list, tuple)): |
| | for idx in range(len(tree)): |
| | for path, v in _traverse_with_names(tree[idx], with_inner_nodes): |
| | yield (str(idx) + "/" + path).rstrip("/"), v |
| | if with_inner_nodes: |
| | yield "", tree |
| | else: |
| | yield "", tree |
| |
|
| |
|
def tree_flatten_with_names(tree):
  """Populates tree_flatten with leaf names.

  This function populates output of tree_flatten with leaf names, using our
  custom name-producing traversal. The custom traversal does NOT have to
  visit leaves in the same order as jax; the two orders are aligned
  automatically via a permutation.

  Args:
    tree: python tree.

  Returns:
    A list of values with names: [(name, value), ...]
  """
  vals, tree_def = jax.tree.flatten(tree)

  # Flatten a "token" tree of leaf indices with the custom traversal; for
  # each produced name this tells us which jax-ordered leaf it refers to.
  token_tree = tree_def.unflatten(range(len(vals)))
  val_names, perm = zip(*_traverse_with_names(token_tree))
  inv_perm = np.argsort(perm)

  # Custom traversal should visit the same number of leaves as jax does.
  assert len(val_names) == len(vals)

  return [(val_names[i], v) for i, v in zip(inv_perm, vals)], tree_def
| |
|
| |
|
def tree_unflatten(names_and_vals):
  """Reverses `tree_flatten_with_names(tree)[0]`: rebuilds the nested dict."""
  return recover_tree(*zip(*names_and_vals))
| |
|
| |
|
def tree_map_with_names(f, tree, *rest):
  """Like jax.tree.map but with a filter on the leaf path name.

  Args:
    f: A function with first parameter `name` (path-like "a/b/c") and remaining
      parameters values of `tree` and `*rest` corresponding to the given `name`
      Should return a new value for parameter `name`.
    tree: The tree of parameters `f` should be applied to.
    *rest: more trees of the exact same structure.

  Returns:
    A tree identical in structure to `tree` and `*rest` but with the leaves the
    result of calling `f` on corresponding name/leaves in `tree` and `*rest`.
  """
  named_leaves, tree_def = tree_flatten_with_names(tree)
  names, leaves = zip(*named_leaves)
  # Leaves of `rest`, aligned by the same naming scheme.
  extra_leaves = [list(zip(*tree_flatten_with_names(t)[0]))[1] for t in rest]
  mapped = [f(*args) for args in zip(names, leaves, *extra_leaves)]
  return tree_def.unflatten(mapped)
| |
|
| |
|
def tree_map_with_regex(f, tree, regex_rules, not_f=lambda x: x, name=None):
  """Apply jax-style tree_map based on regex rules.

  Args:
    f: a function that is being applied to every variable.
    tree: jax tree of arrays.
    regex_rules: a list of tuples `(pattern, args)`, where `pattern` is a regex
      which used for variable matching and `args` are positional arguments
      passed to `f`. If some variable is not matched, we apply `not_f` transform
      which is id by default. If multiple patterns match, then only the first
      rule is applied.
    not_f: optional function which is applied to variables that do not match any
      pattern.
    name: a name of transform for logging purposes.

  Returns:
    a tree, transformed by `f` according to the given rules.
  """
  def _apply_first_match(vname, v):
    for pattern, arg in regex_rules:
      if not re.fullmatch(pattern, vname):
        continue
      # Only one process logs, to avoid log spam in multi-host setups.
      if name and jax.process_index() == 0:
        logging.info("Applying %s to %s with %s due to `%s`",
                     name, vname, arg, pattern)
      return f(v, arg)  # First matching rule wins.
    return not_f(v)
  return tree_map_with_names(_apply_first_match, tree)
| |
|
| |
|
def tree_get(tree, name):
  """Get an entry of pytree by flattened key name, eg a/b/c, with nice error.

  Args:
    tree: the pytree to be queried.
    name: the path to extract from the tree, see below for examples.

  Returns:
    A few examples:
      tree = {'a': 1, 'b': {'c': 2, 'd': 3}}
      tree_get(tree, 'a') == 1
      tree_get(tree, 'b/c') == 2
      tree_get(tree, 'b') == {'c': 2, 'd': 3}
  """
  # Inner nodes are included, so whole sub-trees can be fetched too.
  flattened = dict(_traverse_with_names(tree, with_inner_nodes=True))
  try:
    return flattened[name]
  except KeyError as e:
    # The str subclass keeps the multi-line message readable, since KeyError
    # repr()s its argument by default.
    class Msg(str):
      def __repr__(self):
        return str(self)
    lines = [name, "Available keys:", *flattened, ""]
    msg = "\n".join(lines)
    # Enrich with a "did you mean" suggestion for near-miss names.
    msg = mlc.ConfigDict(flattened)._generate_did_you_mean_message(name, msg)  # pylint: disable=protected-access
    raise KeyError(Msg(msg)) from e
| |
|
| |
|
def tree_replace(tree, replacements):
  """Renames/removes (nested) keys.

  Example usage:

    tree = {'a': {'b': 2, 'c': 3}, 'c': 4}
    replacements = {
        'a/b': 'a/b/x',  # replaces 'a/b' with 'a/b/x'
        '.*c': 'C',      # replaces 'c' with 'C' ('a/c' is removed)
        'C': 'D',        # replaces 'C' (which was 'c') with 'D'
        '.*/c': None,    # removes 'a/c'
    }
    tree2 = tree_replace(tree, replacements)
    assert tree2 == {'D': 4, 'a': {'b': {'x': 2}}}

  Args:
    tree: A nested dictionary.
    replacements: Rules specifying `regex` as keys and `replacement` as values
      to be used with `m = re.match(regex, key)` and `m.expand(replacement)`
      for every `key` independently.

  Note that:
    1. If any rule matches with `replacement=None`, then the key is removed.
    2. The rules are applied in order. It's possible to have multiple
       transformations on a single key.

  Returns:
    Updated `tree` according to rules defined in `replacements`.
  """
  compiled = {re.compile(pat): repl for pat, repl in replacements.items()}

  def rename(key):
    # Apply rules in order; later rules see the output of earlier ones.
    for pat, repl in compiled.items():
      m = pat.match(key)
      if m:
        key = key[:m.start()] + m.expand(repl) + key[m.end():]
    return key

  def should_remove(key):
    return any(repl is None and pat.match(key) for pat, repl in compiled.items())

  names_and_vals, _ = tree_flatten_with_names(tree)
  kept = [(rename(k), v) for k, v in names_and_vals if not should_remove(k)]
  return tree_unflatten(kept)
| |
|
| |
|
def tree_compare(tree1, tree2):
  """Returns `(tree1_only, tree2_only, dtype_shape_mismatch)`."""
  flat1 = flax.traverse_util.flatten_dict(tree1, sep="/")
  flat2 = flax.traverse_util.flatten_dict(tree2, sep="/")
  only1 = set(flat1) - set(flat2)
  only2 = set(flat2) - set(flat1)
  # For keys present in both, record the (dtype, shape) pairs that disagree.
  mismatch = {}
  for k, v in flat1.items():
    if k not in flat2:
      continue
    w = flat2[k]
    if v.dtype != w.dtype or v.shape != w.shape:
      mismatch[k] = [(v.dtype, v.shape), (w.dtype, w.shape)]
  return only1, only2, mismatch
| |
|
| |
|
def tree_filter(tree, mask):
  """Returns nested dict structure with only a subset of children.

  `mask` mirrors `tree`'s dict structure; a boolean leaf of False drops the
  corresponding subtree, True keeps it.
  """
  if not isinstance(tree, dict):
    assert isinstance(mask, bool), f"Mask leaves must be boolean! {mask}"
    return tree
  assert sorted(tree.keys()) == sorted(mask.keys()), (
      f"Keys in tree and mask are not equal! {tree.keys()} != {mask.keys()}")
  filtered = {}
  for key, subtree in tree.items():
    if mask[key] is not False:
      filtered[key] = tree_filter(subtree, mask[key])
  return filtered
| |
|
| |
|
def recover_dtype(a):
  """Numpy's `save` stores bfloat16 type as "void" type, so we recover it."""
  if not hasattr(a, "dtype") or a.dtype.type is not np.void:
    return a  # Anything but raw-void arrays passes through untouched.
  assert a.itemsize == 2, "Unknown dtype!"
  # A 2-byte void item is a serialized bfloat16: reinterpret the raw bytes.
  return a.view(jax.numpy.bfloat16)
| |
|
| |
|
def recover_tree(keys, values):
  """Recovers a tree as a nested dict from flat names and values.

  This is useful to analyze checkpoints saved by our programs without access
  to the exact source code of the experiment; in particular it allows
  extracting and reusing subtrees of a checkpoint, e.g. the parameters.

  Args:
    keys: a list of keys, where '/' is used as separator between nodes.
    values: a list of leaf values.

  Returns:
    A nested tree-like dict.
  """
  tree = {}
  children = collections.defaultdict(list)
  # Direct leaves go straight into `tree`; deeper entries are grouped by
  # their first path component and recursed into.
  for key, val in zip(keys, values):
    if "/" in key:
      prefix, suffix = key.split("/", 1)
      children[prefix].append((suffix, val))
    else:
      tree[key] = val
  for prefix, pairs in children.items():
    sub_keys, sub_vals = zip(*pairs)
    tree[prefix] = recover_tree(sub_keys, sub_vals)
  return tree
| |
|
| |
|
def tssave(mngr, pytree, path, on_commit=lambda *_, **__: None):
  """Save pytree using jax tensorstore-based checkpoint manager.

  NOTE: When overwriting an existing checkpoint with a different pytree, the
  result is, counterintuitively, the union of both, not only the new one.

  Args:
    mngr: An instance of GlobalAsyncCheckpointManager.
    pytree: What to store; any pytree of arrays.
    path: Where to save the pytree. Creates subfolders as needed.
    on_commit: A callback when writing is done, see `mngr.serialize`.
  """
  names, vals = zip(*tree_flatten_with_names(pytree)[0])

  # "~" is reserved: it encodes "/" in the on-disk array file names below.
  bad = [name for name in names if "~" in name]
  if bad:
    raise ValueError(f"Symbol '~' is not allowed in names. Found in {bad[0]}.")

  gfile.makedirs(path)
  with jax.transfer_guard("allow"):
    disk_names = [name.replace("/", "~") for name in names]
    mngr.serialize_with_paths(
        list(vals), [os.path.join(path, n) for n in disk_names],
        on_commit_callback=functools.partial(on_commit, array_names=disk_names))
| |
|
| |
|
def save_checkpoint_ts(mngr, checkpoint, path, step, keep=True):
  """Preemption-safe saving of checkpoints using tssave.

  Writes the checkpoint to `{path}-{step:09d}` (with a "-tmp" suffix when
  `keep=False`) and only marks it as current, via the `{path}-LAST` marker
  file, once the async write has fully committed. This way a preemption
  mid-write never leaves `-LAST` pointing at a partial checkpoint.

  Args:
    mngr: An instance of GlobalAsyncCheckpointManager (forwarded to tssave).
    checkpoint: the pytree of arrays to store.
    path: base path/prefix for the checkpoint directories and marker files.
    step: the training step, used to name the checkpoint directory.
    keep: when False, the directory is suffixed "-tmp" and the previous -tmp
      checkpoint is deleted once this one is committed.
  """

  def _on_commit_callback(array_names):
    # The write has fully committed: record the new checkpoint name in a
    # staging file first ("-CUR"), ...
    with gfile.GFile(f"{path}-CUR", "w") as f:
      f.write(curr)

    # ... remember which checkpoint was previously current, ...
    last = ""
    if gfile.exists(f"{path}-LAST"):
      with gfile.GFile(f"{path}-LAST", "r") as f:
        last = f.read().strip()

    # ... then atomically mark the new one as current.
    gfile.rename(f"{path}-CUR", f"{path}-LAST", overwrite=True)

    # Delete the previous checkpoint if it was a temporary one.
    if last.endswith("-tmp"):
      # rmtree can be slow; delete the per-array subfolders in parallel.
      multiprocessing.pool.ThreadPool().map(
          gfile.rmtree,
          [f"{path}-{last}/{name}" for name in array_names])
      gfile.rmtree(f"{path}-{last}")

  # `curr` is read by the callback above via closure, after the async write.
  curr = f"{step:09d}{'-tmp' if not keep else ''}"
  tssave(mngr, checkpoint, f"{path}-{curr}", _on_commit_callback)
| |
|
| |
|
def load_checkpoint_ts(path, **tsload_kw):
  """Loads a big_vision checkpoint saved by `save_checkpoint_ts`."""
  target = path
  # If a "-LAST" marker exists, follow it to the most recent committed step;
  # otherwise assume `path` already names a concrete checkpoint directory.
  # Best-effort: any failure reading the marker falls back to `path` itself.
  try:
    with gfile.GFile(f"{path}-LAST", "r") as f:
      target = f"{path}-{f.read().strip()}"
  except Exception:  # pylint: disable=broad-except
    pass

  return tsload(target, **tsload_kw)
| |
|
| |
|
def tsload(path, *, tree=None, shardings=None, regex=None):
  """Loads a tensorstore-based array-tree from disk.

  If `tree` is provided, the array names to load and the target structure are
  derived from it. If `tree` is None, the names are derived from the array
  filenames on disk (optionally filtered by `regex`) and the tree structure
  is reconstructed from those names with `recover_tree`.

  Arrays are loaded to CPU/TPU/GPU memory as described by `shardings`, a
  (possibly mixed) pytree of shardings. `shardings` may be a prefix tree of
  `tree`; it is automatically broadcast to the full tree, so e.g. a single
  `jax.sharding.SingleDeviceSharding(jax.devices('cpu')[0])` applies to
  every array.

  Args:
    path: directory where the checkpoint arrays are stored.
    tree: target pytree defining array names and structure, or None to infer
      it from the filenames on disk.
    shardings: prefix pytree (w.r.t. `tree`) of target shardings, or None for
      single-device CPU placement.
    regex: regex to filter array names from disk; only valid when `tree` is
      not given.

  Returns:
    A pytree of loaded arrays with the same structure as the (inferred) tree.
  """
  if (tree is not None) and (regex is not None):
    raise ValueError("If tree is specified, regex filtering is not allowed.")

  if tree is None:
    # On disk, "/" in array names is encoded as "~"; undo that here.
    disk_names = {p.rstrip("/").replace("~", "/") for p in gfile.listdir(path)}
    name_filter = re.compile(regex) if regex is not None else re.compile(".*")
    kept_names = [n for n in disk_names if name_filter.match(n)]
    tree = recover_tree(kept_names, [0] * len(kept_names))

  names_and_vals, tree_def = tree_flatten_with_names(tree)
  names_to_load, _ = zip(*names_and_vals)

  if shardings is None:
    shardings = jax.sharding.SingleDeviceSharding(
        jax.local_devices(backend="cpu")[0]
    )
  # Expand the (possibly prefix) shardings tree to one sharding per array.
  shardings = list(jax.tree.leaves(tree_broadcast(shardings, tree)))

  file_paths = [os.path.join(path, name.replace("/", "~"))
                for name in names_to_load]
  specs = [array_serial.get_tensorstore_spec(p) for p in file_paths]
  arrays = array_serial.run_deserialization(shardings, specs, concurrent_gb=64)
  return tree_def.unflatten(arrays)
| |
|
| |
|
def steps(prefix, config, data_size=None, batch_size=None, total_steps=None,
          default=ValueError):
  """Reads the duration named `prefix` from `config`, converted to steps.

  A duration may be written in the config as any one of `{prefix}_steps`,
  `{prefix}_examples`, `{prefix}_epochs` or `{prefix}_percent`; this helper
  converts whichever is present into a step count, so the training code only
  ever deals with steps. Non-integer results are rounded to the nearest step
  (but a non-zero duration never rounds below one step).

  Args:
    prefix: name of the duration to look up.
    config: the dictionary (config) from which to read the duration.
    data_size: number of training examples per epoch.
    batch_size: number of examples processed per step.
    total_steps: total number of training steps to run.
    default: value returned when no duration named `prefix` is present;
      pass `ValueError` (the default) to raise instead.

  Returns:
    The number of steps from the config, or the default value.

  Raises:
    ValueError: if no such duration is in the config and no default is set,
      or if a required conversion input (batch/data size, ...) is missing.
  """
  # Collect which of the four spellings are present (with non-negative value).
  matches = set()
  for suffix in ("steps", "examples", "epochs", "percent"):
    val = config.get(f"{prefix}_{suffix}")
    if val is not None and val >= 0:
      matches.add(f"{prefix}_{suffix}")
  # Note that this allows things like "decay_steps: -1" as non-existing.
  assert len(matches) <= 1, f"Only one of '{matches}' should be defined."

  if f"{prefix}_steps" in matches:
    return config[f"{prefix}_steps"]

  def _round_steps(x):
    # Round to the nearest step, but never turn a non-zero duration into 0.
    return max(1, round(x)) if x else 0

  if batch_size and f"{prefix}_examples" in matches:
    return _round_steps(config[f"{prefix}_examples"] / batch_size)

  if batch_size and data_size and f"{prefix}_epochs" in matches:
    return _round_steps(config[f"{prefix}_epochs"] * (data_size / batch_size))

  if total_steps and f"{prefix}_percent" in matches:
    pct = config[f"{prefix}_percent"]
    assert 0.0 <= pct <= 1.0, (
        f"Percents should lie in [0.0, 1.0], but {prefix}_percent is {pct}")
    return _round_steps(pct * total_steps)

  if default is ValueError:
    raise ValueError(
        f"Cannot convert {prefix} to steps, due to missing batch_size "
        f"({batch_size}), data_size ({data_size}), total_steps ({total_steps})"
        ", or corresponding entry in config:\n" + "\n".join(config.keys()))

  return default
| |
|
| |
|
def create_learning_rate_schedule(
    total_steps, batch_size=None, data_size=None,
    base=1.0, decay_type="stair",
    scale_with_batchsize=False, **kw):
  """Creates learning rate schedule, see (internal link).

  Args:
    total_steps: The total number of steps to run.
    batch_size: The global batch-size optionally used for scaling.
    data_size: Number of examples in the training data (for epoch conversion).
    base: The starting learning-rate (without warmup).
    decay_type: 'linear' or 'cosine', 'rsqrt', 'stair'.
    scale_with_batchsize: Whether or not to scale lr automatically.
    **kw: extra arguments specific to individual decay_types. Also contains
      declaration of `{warmup,cooldown}_{steps,epochs,examples}` that applies
      on top of any/all decay_type.

  Returns:
    A function learning_rate(step): float -> {"learning_rate": float}.
  """

  def to_steps(name, default=0):
    # Resolves `{name}_{steps,examples,epochs,percent}` entries in `kw`.
    return steps(name, kw, data_size, batch_size, total_steps, default=default)

  warmup_steps = to_steps("warmup")
  cooldown_steps = to_steps("cooldown")

  # A warmup that is at least as long as the run would leave `progress`
  # degenerate below; `total_steps <= 1` is exempted (presumably eval-only /
  # single-step runs — TODO confirm).
  assert (total_steps <= 1) or (warmup_steps < total_steps), (
      "warmup_steps is >= total_steps")

  def step_fn(step):
    """Step to learning rate function."""
    lr = base

    # Scale the learning rate linearly with the global batch size,
    # normalized to a reference batch size of 256.
    if scale_with_batchsize:
      lr = lr * batch_size / 256.0

    # Fraction of post-warmup training completed, clipped to [0, 1].
    progress = (step - warmup_steps) / float(total_steps - warmup_steps)
    progress = jnp.clip(progress, 0.0, 1.0)
    if decay_type in ("linear", "polynomial"):
      power = kw.get("power", 1)
      # Final lr value; `end` and `linear_end` are alternative spellings.
      zero = kw.get("end", kw.get("linear_end", 0))
      lr = zero + (lr - zero) * (1.0 - progress) ** power
    elif decay_type == "cosine":
      lr = lr * 0.5 * (1. + jnp.cos(jnp.pi * progress))
    elif decay_type == "rsqrt":
      # Timescale/shift may also be given as steps/examples/epochs/percent;
      # the plain `timescale`/`shift` kw entries act as fallback defaults.
      t = to_steps("timescale", default=kw.get("timescale", 10_000))
      shift = to_steps("shift", default=kw.get("shift", 0))
      # During warmup the rsqrt decay is frozen at its step-0 value.
      lr = jnp.where(
          warmup_steps <= step,
          lr / jnp.sqrt(1 + (step + shift - warmup_steps) / t),
          lr / jnp.sqrt(1 + shift / t))
    elif decay_type == "stair":
      # Piecewise-constant schedule: after passing kw["steps"][j], the lr is
      # multiplied by kw["mults"][j] (index 0 of the table is the "no change"
      # multiplier 1.0 before the first boundary).
      i = jnp.searchsorted(jnp.array(kw.get("steps", [])), step + 1)
      lr = lr * jnp.take(jnp.array([1.0] + list(kw.get("mults", []))), i)
    else:
      raise ValueError(f"Unknown lr type {decay_type}")

    # Linear warmup ramp at the start; linear cooldown to zero at the end.
    if warmup_steps:
      lr = lr * jnp.minimum(1., step / warmup_steps)
    if cooldown_steps:
      lr = lr * jnp.minimum(1., (total_steps - step) / cooldown_steps)

    return jnp.asarray(lr, dtype=jnp.float32)

  return step_fn
| |
|
| |
|
def get_mixup(rng, p):
  """Perform mixup https://arxiv.org/abs/1710.09412."""
  rng, beta_rng = jax.random.split(rng)
  coeff = jax.random.beta(beta_rng, p, p)
  coeff = jnp.maximum(coeff, 1.0 - coeff)  # Dominant weight is >= 0.5.

  def _mixup(*things, **more_things):
    def mix(x):
      # Blend each example with its neighbor (batch rolled by one).
      return coeff * x + (1 - coeff) * jnp.roll(x, shift=1, axis=0)
    return rng, *jax.tree.map(mix, (things, more_things))

  return _mixup
| |
|
| |
|
| | |
def mixup(rng, *things, p, **more_things):
  """One-shot convenience wrapper around `get_mixup`."""
  mixer = get_mixup(rng, p)
  return mixer(*things, **more_things)
| |
|
| |
|
def sync():
  """Syncs hosts and empties async computation queue."""
  sharding = jax.sharding.PositionalSharding(jax.devices())
  ones = reshard(np.ones(jax.device_count()), sharding)
  # Block until the (trivial) global computation is done on all devices.
  jax.jit(jnp.sum)(ones).block_until_ready()
| |
|
| |
|
def check_and_compile_patterns(patterns):
  """Validates and compiles a list of param-patterns.

  The validation consists of checking for common mistakes, currently only that
  the pattern does not start with a slash, because unlike FLAX, our parameter
  names don't start with a slash.

  Args:
    patterns: a single (string) pattern (regex), or a list of patterns.

  Returns:
    A list of compiled and verified regexes.
  """
  if isinstance(patterns, str):
    patterns = [patterns]
  assert isinstance(patterns, (list, tuple)), patterns

  compiled = []
  for pattern in patterns:
    assert not pattern.startswith("/"), (
        f"Big vision parameter names never start with '/': '{pattern}")
    compiled.append(re.compile(pattern))
  return compiled
| |
|
| |
|
def make_mask_trees(tree, patterns, *, log=None):
  """Returns a boolean mask tree for every pattern (only first match)."""
  compiled_patterns = check_and_compile_patterns(patterns)

  def _first_match_flags(name, _):
    # One flag per pattern; at most one is True (the first full match).
    flags = []
    already_matched = False
    for pattern in compiled_patterns:
      is_match = (not already_matched) and bool(pattern.fullmatch(name))
      flags.append(is_match)
      already_matched = already_matched or is_match
    if log is not None and True in flags and jax.process_index() == 0:
      logging.info("%s: %s - matched by %s", log, name,
                   patterns[flags.index(True)])
    return np.array(flags)

  multimask = tree_map_with_names(_first_match_flags, tree)
  # Split the per-leaf flag vectors into one boolean tree per pattern.
  return [
      jax.tree.map(lambda flags, i=idx: flags[i], multimask)
      for idx in range(len(patterns))
  ]
| |
|
| |
|
@contextlib.contextmanager
def profile(name, ttl=3 * 365 * 24 * 3600, noop=False):
  """Context manager that profiles the enclosed code region.

  Args:
    name: name of the profiling session.
    ttl: time-to-live of the captured profile, in seconds.
    noop: if True, do nothing (no profiling).

  Yields:
    Nothing; profiling starts on entry and stops on exit.
  """
  sess = None if noop else startstop_prof_at_steps(None, name=name, ttl=ttl)
  try:
    yield
  finally:
    # Stop the session even when the profiled block raises; the previous
    # version skipped the stop call on exceptions and leaked the session.
    if not noop:
      startstop_prof_at_steps(sess, name=name, ttl=ttl)
| |
|
| |
|
def startstop_prof(sess, step=None, first_step=0,
                   log_steps=1, surround=10, **kw):
  """Runs the profiler for `surround` steps around the next `log_steps`."""
  # First multiple of `log_steps` strictly after `first_step`.
  next_log = first_step + log_steps - (first_step % log_steps)
  # Start half the window early, but never at/before `first_step` itself.
  begin = max(next_log - surround // 2, first_step + 1)
  return startstop_prof_at_steps(sess, step, begin, begin + surround, **kw)
| |
|
| |
|
def startstop_prof_at_steps(
    sess, step=None, first_step=None, last_step=None,
    name="steps", ttl=3 * 365 * 24 * 3600):
  """No-op profiler hook; presumably backed by a real (internal) profiler."""
  # All arguments intentionally unused in this stub.
  del sess, step, first_step, last_step, name, ttl
| |
|
| |
|
| | |
| | |
class BigVisionMetricWriter:
  """A class for logging metrics.

  Metrics are recorded via `measure` between `step_start`/`step_end` calls.
  On the main process (process_index 0), each finished step is appended as a
  JSON line to a metrics file in `workdir` (when given); file writes happen
  on a single background thread so they do not block the caller.
  """

  def __init__(self, xid=-1, wid=-1, workdir=None, config=None):
    """Initializes the writer.

    Args:
      xid: experiment id; used in the metrics filename when both `xid` and
        `wid` are given (i.e. != -1).
      wid: work-unit id, see `xid`.
      workdir: directory for the metrics file and `config.json`; if falsy,
        nothing is written to disk.
      config: optional config object providing `to_json()`, dumped once to
        `workdir/config.json`.
    """
    self.step_start(0)
    if jax.process_index() != 0: return  # Only the main process writes.

    # Single worker thread: writes happen in the background but stay ordered.
    self.pool = multiprocessing.pool.ThreadPool(1)
    self.fname = None
    if workdir:
      if xid != -1 and wid != -1:
        self.fname = os.path.join(workdir,
                                  f"big_vision_{xid}_{wid}_metrics.txt")
      else:
        self.fname = os.path.join(workdir, "big_vision_metrics.txt")
      if config:
        with gfile.GFile(os.path.join(workdir, "config.json"), "w") as f:
          f.write(config.to_json())

  def step_start(self, step):
    """Starts recording metrics for `step`."""
    self.step = step
    self.step_metrics = {}  # Metrics measured during the current step.

  def measure(self, name, value):
    """Logs the metric value."""
    if jax.process_index() != 0: return  # Only the main process logs.

    # Accept scalars wrapped in (squeezable) arrays, e.g. shape (1, 1).
    value = np.array(value).squeeze()

    # Scalars are recorded as floats; larger arrays only by their shape.
    value = float(value) if value.ndim == 0 else value.shape

    logging.info(f"\u001b[35m[{self.step}]\u001b[0m {name} = {value}")
    logging.flush()
    self.step_metrics[name] = value

    return value  # Returned for the caller's convenience.

  def step_end(self):
    """Ends a training step, write its full row."""
    if not self.step_metrics: return

    def write(metrics):
      with gfile.GFile(self.fname, "a") as f:
        f.write(json.dumps({"step": self.step, **metrics}) + "\n")

    if self.fname:
      # The no-op `apply` blocks until the previously queued write finished,
      # so at most one write is ever pending on the single-thread pool.
      self.pool.apply(lambda: None)
      self.pool.apply_async(write, (self.step_metrics,))

  def close(self):
    """Flushes the last step and shuts down the background writer."""
    self.step_end()
    if jax.process_index() == 0:
      self.pool.close()
      self.pool.join()
| |
|
| |
|
def maybe_cleanup_workdir(workdir, cleanup, info):
  """Potentially removes workdirs at end of run for cleanup.

  Args:
    workdir: path of the working directory; no-op when falsy.
    cleanup: when True, delete `workdir` (main process only); otherwise just
      log where the logs/checkpoints live.
    info: logging callback, e.g. `logging.info`.
  """
  if not workdir:
    return

  if not cleanup:
    info("Logs/checkpoints are in %s", workdir)
  elif jax.process_index() == 0:
    gfile.rmtree(workdir)
    # Best-effort removal of the (possibly now-empty) parent directory.
    # NOTE: this previously caught `tf.errors.OpError`, but `tf` is not
    # imported in this module (only `tensorflow.io.gfile`), so the handler
    # itself raised NameError. Catch broadly instead; this is best-effort.
    try:
      gfile.remove(os.path.join(workdir, ".."))
    except Exception:
      pass
| |
|
| |
|
def tree_broadcast(prefix, target):
  """Broadcasts a prefix tree to a full tree.

  Input-output examples:
    1. prefix: {"x": 10, "y": 20}
       target: {"x": {"a": 1, "b": 2}, "y": 3}
       Result: {"x": {"a": 10, "b": 10}, "y": 20}

    2. prefix: 100
       target: {"x": {"a": 1, "b": 2}, "y": 3}
       Result: {"x": {"a": 100, "b": 100}, "y": 100}

    3. prefix: {"x": 10}
       target: {"x": {"a": 1, "b": 2}, "y": 3}
       Result: ValueError

  Args:
    prefix: prefix pytree.
    target: broadcast target for a prefix tree.

  Returns:
    prefix tree broadcasted to a target tree.
  """
  def _expand(leaf, subtree):
    # Copy `leaf` onto every leaf position of the matching target subtree.
    return jax.tree.map(lambda _: leaf, subtree)
  return jax.tree.map(_expand, prefix, target)
| |
|
| |
|
def reshard(tree, shardings):
  """Take an arbitrarily* sharded pytree and shard it according to `shardings`.

  This is a no-op for tree elements which are already sharded as requested.

  *Arrays that are fully addressable (for example, CPU arrays) are assumed to be
  identical (i.e. replicated) across hosts.

  *It does not work if an element of `tree` is not fully-addressable, unless its
  sharding is already consistent with the target sharding.
  If this is needed, please ping lbeyer@ or akolesnikov@.

  Args:
    tree: a pytree of arrays.
    shardings: a (prefix) pytree of jax array shardings.
  Returns:
    A pytree of global jax arrays that follows provided shardings.
  """
  def _make_global_arr(x, shard, shape):
    # Fast path: no data movement when `x` already has the requested sharding.
    if hasattr(x, "sharding") and x.sharding.is_equivalent_to(shard, len(shape)):
      return x
    if not getattr(x, "is_fully_addressable", True):
      raise RuntimeError("Trying to reshard a non-fully-addressable array. "
                         "Please see the doc-comment for detailed explanation.")
    # Fetch to host, then place each device's slice individually; the slices
    # come from the target sharding's device->index map.
    x = jax.device_get(x)
    xs = [jax.device_put(x[s], device=d)
          for d, s in shard.addressable_devices_indices_map(shape).items()]
    return jax.make_array_from_single_device_arrays(shape, shard, xs)

  shapes = jax.tree.map(np.shape, tree)
  # `shardings` may be a prefix tree of `tree`; expand to the full structure.
  shardings = tree_broadcast(shardings, tree)
  return jax.tree.map(_make_global_arr, tree, shardings, shapes)
| |
|
| |
|
def put_cpu(x):
  """Places array/pytree on a CPU device."""
  cpu_device = jax.local_devices(backend="cpu")[0]
  return jax.device_put(x, cpu_device)
| |
|
| |
|
def make_fsarray_from_local_slice(local_slice, global_devices):
  """Create a fully-sharded global device array from local host arrays.

  Args:
    local_slice: Something convertible to a numpy array (eg also TF tensors)
      that is this host's slice of the global array.
    global_devices: The list of global devices. Needed for consistent ordering.

  Returns:
    The global on-device array which consists of all local slices stacked
    together in the order consistent with the devices.
  """
  # 1D mesh over all devices, sharded along the leading ("devices") axis.
  mesh = jax.sharding.Mesh(global_devices, ("devices",))
  sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec("devices"))
  local_ds = mesh.local_devices

  # Go through memoryview so buffer-protocol inputs (e.g. TF tensors) convert
  # without an extra copy — TODO confirm this holds for all input types.
  x = np.asarray(memoryview(local_slice))
  # Split this host's slice evenly across its local devices along axis 0
  # (requires the leading dim to be divisible by the local device count).
  xs = jax.device_put(np.split(x, len(local_ds), axis=0), local_ds)

  # Assumes every process contributes an equally-sized slice.
  global_shape = (x.shape[0] * jax.process_count(), *x.shape[1:])
  return jax.make_array_from_single_device_arrays(global_shape, sharding, xs)
| |
|
| |
|
def get_local_slice_from_fsarray(global_array):
  """Return numpy array for the host-local slice of fully-sharded array.

  Args:
    global_array: JAX array, globally sharded on devices across hosts.

  Returns:
    NumPy array that holds the part of `global_array` that is held by the
    devices on the host that calls this function.
  """
  # Only sharding along the leading dimension is supported: every shard must
  # cover the full extent of all other dimensions.
  for shard in global_array.addressable_shards:
    assert all(idx == slice(None) for idx in shard.index[1:]), (
        f"global_array is sharded along non-first dimensions:\n{shard.index}")

  # Look shards up by device and follow the mesh's local-device order, so the
  # concatenation is deterministic regardless of how `addressable_shards` is
  # ordered. NOTE(review): assumes the sharding exposes `.mesh` (i.e. is a
  # NamedSharding) — confirm against callers.
  m = {s.device: s for s in global_array.addressable_shards}
  local_shards = [m[d] for d in global_array.sharding.mesh.local_devices]
  return np.concatenate([jax.device_get(s.data) for s in local_shards], axis=0)
| |
|
| |
|
def assert_local_slices_same(*global_arrays):
  """Check whether all `global_arrays` have local slices at the same indices.

  Args:
    *global_arrays: JAX arrays whose addressable shards are compared.

  Raises:
    AssertionError: if the arrays' addressable shards do not cover identical
      index ranges.
  """
  # Each shard's index is a tuple of `slice` objects; compare them as
  # hashable (start, stop, step) triples. NOTE: this previously accessed
  # `idx.end`, which does not exist on `slice` (the attribute is `stop`) and
  # raised AttributeError on every call.
  slices = [
      tuple(
          tuple((idx.start, idx.stop, idx.step) for idx in s.index)
          for s in a.addressable_shards)
      for a in global_arrays]
  assert len(set(slices)) == 1, f"Not all slices are the same: {slices}"
| |
|
| |
|
| | |
| | |
def jit_cpu(**extra_kwargs):
  """Decorator that jit-compiles a function with CPU-placed outputs.

  Outputs are placed on the first local CPU device via `out_shardings`.

  Args:
    **extra_kwargs: forwarded to `jax.jit`.

  Returns:
    A decorator wrapping a function into its CPU-jitted version.
  """
  def _decorator(fun):
    jitted = None  # Created lazily so devices are only queried on first call.

    @functools.wraps(fun)
    def _wrapped(*args, **kwargs):
      nonlocal jitted
      if jitted is None:
        # Build the jitted callable once and reuse it. The previous version
        # called `jax.jit(fun, ...)` on every invocation, creating a fresh
        # wrapper each time and defeating the compilation cache.
        sh = jax.sharding.SingleDeviceSharding(
            jax.local_devices(backend="cpu")[0]
        )
        jitted = jax.jit(fun, **extra_kwargs, out_shardings=sh)
      return jitted(*args, **kwargs)
    return _wrapped
  return _decorator
| |
|
| |
|
def create_device_mesh(
    config_mesh,
    *,
    allow_split_physical_axes=False,
):
  """Returns a JAX device mesh.

  Args:
    config_mesh: A list of tuples of (axis_name, axis_size). It is advised to
      sort the axis in increasing order of network communication intensity.
    allow_split_physical_axes: Whether to allow splitting physical axes.
  """
  axis_names, requested_sizes = tuple(zip(*config_mesh))
  devices = jax.devices()
  # Resolve any -1 ("infer") axis sizes against the total device count.
  resolved_sizes = np.array(devices).reshape(requested_sizes).shape
  hw_mesh = mesh_utils.create_device_mesh(
      resolved_sizes,
      devices=devices,
      allow_split_physical_axes=allow_split_physical_axes)
  return jax.sharding.Mesh(hw_mesh, axis_names)
| |
|