import dataclasses
import functools
import logging
import platform
from typing import Any, Dict, Optional, Tuple

import etils.epath as epath
import flax.nnx as nnx
from flax.nnx import rnglib
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import optax
import tqdm_loggable.auto as tqdm
import wandb

import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders
from openpi.models.pi0_fast import Pi0FAST, make_attn_mask


@dataclasses.dataclass
class OftTrainingConfig:
    """Extra training options in the style of openvla-oft."""

    use_l1_regression: bool = False
    use_diffusion: bool = True
    use_discrete_tokens: bool = False
    num_diffusion_steps_train: int = 25
    # NOTE: these beta values are currently not forwarded to DiffusionScheduler,
    # which uses its own defaults.
    diffusion_beta_start: float = 0.0001
    diffusion_beta_end: float = 0.00005
    grad_accumulation_steps: int = 1
    use_val_set: bool = False
    val_freq: int = 10_000


class DiffusionScheduler:
    """Minimal DDIM-style scheduler with a linear beta schedule."""

    def __init__(self, num_train_timesteps: int, beta_start: float = 0.0001, beta_end: float = 0.02):
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas)
        self.alphas_cumprod_prev = jnp.concatenate([jnp.array([1.0]), self.alphas_cumprod[:-1]])
        self.variance = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
        self.variance = jnp.concatenate([jnp.array([0.0]), self.variance[1:]])
        self.timesteps = jnp.arange(0, num_train_timesteps)

    def set_timesteps(self, num_inference_steps: int):
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // num_inference_steps
        self.timesteps = jnp.arange(0, self.num_train_timesteps, step_ratio)

    def step(self, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray) -> Dict[str, jnp.ndarray]:
        # Deterministic DDIM update (eta = 0).
        alpha_cumprod = self.alphas_cumprod[timestep]
        alpha_cumprod_prev = self.alphas_cumprod_prev[timestep]
        # Predict x_0 from the noise estimate.
        pred_original_sample = (sample - jnp.sqrt(1 - alpha_cumprod) * model_output) / jnp.sqrt(alpha_cumprod)
        # Step from x_t to x_{t-1}.
        pred_sample_direction = jnp.sqrt(1 - alpha_cumprod_prev) * model_output
        prev_sample = jnp.sqrt(alpha_cumprod_prev) * pred_original_sample + pred_sample_direction
        return {"prev_sample": prev_sample}


class TimeEncoder(nnx.Module):
    """Embeds scalar diffusion timesteps into the LLM hidden dimension."""

    def __init__(self, llm_dim: int, rngs: at.KeyArrayLike | None = None):
        super().__init__()
        self.llm_dim = llm_dim
        if rngs is None:
            rngs = jax.random.key(0)
        rngs_obj = rnglib.Rngs(params=rngs)
        self.time_embedding = nnx.Linear(1, llm_dim, rngs=rngs_obj)
        self.time_mlp = nnx.Sequential(
            nnx.Linear(llm_dim, llm_dim, rngs=rngs_obj),
            nnx.relu,
            nnx.Linear(llm_dim, llm_dim, rngs=rngs_obj),
        )

    def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray:
        # timesteps: (batch_size,)
        timesteps = timesteps.astype(jnp.float32)
        time_emb = self.time_embedding(timesteps[:, None])  # (batch_size, llm_dim)
        time_emb = self.time_mlp(time_emb)
        return time_emb
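
# Illustrative sketch (an addition for documentation, never called by training): with
# an oracle noise predictor, sweeping `DiffusionScheduler.step` from t = T-1 down to
# t = 0 recovers x_0 exactly, which is why sampling must traverse the schedule in
# reverse order.
def _ddim_sweep_sketch():
    sched = DiffusionScheduler(num_train_timesteps=25)
    x0 = jnp.ones((3,))
    eps = jax.random.normal(jax.random.key(0), (3,))
    t_max = sched.num_train_timesteps - 1
    # Forward noising to the final timestep: x_T = sqrt(abar_T) x_0 + sqrt(1 - abar_T) eps.
    x = jnp.sqrt(sched.alphas_cumprod[t_max]) * x0 + jnp.sqrt(1 - sched.alphas_cumprod[t_max]) * eps
    for t in range(t_max, -1, -1):
        x = sched.step(eps, t, x)["prev_sample"]
    assert jnp.allclose(x, x0, atol=1e-4)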

class DiffusionActionHead(nnx.Module):
    """Predicts diffusion noise from LLM hidden states, following openvla-oft."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        action_dim: int,
        num_diffusion_steps: int,
        rngs: at.KeyArrayLike | None = None,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.action_dim = action_dim
        self.num_diffusion_steps_train = num_diffusion_steps
        if rngs is None:
            rngs = jax.random.key(0)
        rngs_obj = rnglib.Rngs(params=rngs)
        # Noise predictor MLP.
        self.noise_predictor = nnx.Sequential(
            nnx.Linear(input_dim, hidden_dim, rngs=rngs_obj),
            nnx.relu,
            nnx.Linear(hidden_dim, hidden_dim, rngs=rngs_obj),
            nnx.relu,
            nnx.Linear(hidden_dim, action_dim, rngs=rngs_obj),
        )
        # Timestep encoder.
        self.time_encoder = TimeEncoder(hidden_dim, rngs=rngs)
        # Diffusion scheduler.
        self.noise_scheduler = DiffusionScheduler(num_diffusion_steps)

    def sample_noisy_actions(self, actions: jnp.ndarray, rng: at.KeyArrayLike) -> Dict[str, jnp.ndarray]:
        batch_size = actions.shape[0]
        # Use independent keys for the timestep and noise draws.
        t_rng, noise_rng = jax.random.split(rng)
        # Sample timesteps.
        timesteps = jax.random.randint(t_rng, (batch_size,), 0, self.num_diffusion_steps_train)
        # Generate noise.
        noise = jax.random.normal(noise_rng, actions.shape)
        # Add noise to the actions: x_t = sqrt(abar_t) x_0 + sqrt(1 - abar_t) eps.
        alpha_cumprod = self.noise_scheduler.alphas_cumprod[timesteps]
        alpha_cumprod = alpha_cumprod.reshape(-1, 1, 1)  # (batch_size, 1, 1)
        noisy_actions = jnp.sqrt(alpha_cumprod) * actions + jnp.sqrt(1 - alpha_cumprod) * noise
        # Encode the timesteps.
        diffusion_timestep_embeddings = self.time_encoder(timesteps)
        return {
            "noise": noise,
            "noisy_actions": noisy_actions,
            "diffusion_timestep_embeddings": diffusion_timestep_embeddings,
            "timesteps": timesteps,
        }

    def predict_noise(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
        return self.noise_predictor(hidden_states)


class NoisyActionProjector(nnx.Module):
    """Projects noisy actions into the LLM embedding space."""

    def __init__(self, input_dim: int, llm_dim: int, rngs: at.KeyArrayLike | None = None):
        super().__init__()
        self.llm_dim = llm_dim
        if rngs is None:
            rngs = jax.random.key(0)
        rngs_obj = rnglib.Rngs(params=rngs)
        self.projection = nnx.Linear(input_dim, llm_dim, rngs=rngs_obj)

    def __call__(self, noisy_actions: jnp.ndarray) -> jnp.ndarray:
        return self.projection(noisy_actions)
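
# Illustrative sketch (documentation only): the expected shape flow through the
# diffusion components above. `llm_dim`, `action_dim`, and `horizon` are made-up
# example sizes, not the real model dimensions.
def _diffusion_head_shapes_sketch():
    llm_dim, action_dim, horizon = 32, 7, 5
    head = DiffusionActionHead(llm_dim, llm_dim, action_dim, num_diffusion_steps=25, rngs=jax.random.key(0))
    projector = NoisyActionProjector(action_dim, llm_dim, rngs=jax.random.key(1))
    actions = jnp.zeros((2, horizon, action_dim))
    noisy = head.sample_noisy_actions(actions, jax.random.key(2))
    assert noisy["noisy_actions"].shape == (2, horizon, action_dim)
    assert noisy["diffusion_timestep_embeddings"].shape == (2, llm_dim)
    # The projector lifts noisy actions into the LLM embedding space before concatenation.
    emb = projector(noisy["noisy_actions"])
    assert emb.shape == (2, horizon, llm_dim)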

def init_logging():
    """Custom logging format for better readability."""
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    logger.handlers[0].setFormatter(formatter)


def init_wandb(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    *,
    resuming: bool,
    log_code: bool = False,
    enabled: bool = True,
):
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        # Build a descriptive run name in the openvla-oft style.
        run_id = f"{config.exp_name}+oft"
        # LoRA suffix.
        if "lora" in str(getattr(config.model, "paligemma_variant", "")):
            run_id += "+lora"
        if config.ema_decay is None:
            run_id += "+no_ema"
        # Training-mode suffixes.
        if oft_config.use_l1_regression:
            run_id += "+l1_regression"
        if oft_config.use_diffusion:
            run_id += "+diffusion"
        if oft_config.use_discrete_tokens:
            run_id += "+discrete"
        wandb.init(
            name=run_id,
            config={**dataclasses.asdict(config), **dataclasses.asdict(oft_config)},
            project=config.project_name,
        )
        if wandb.run is not None:
            (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code and wandb.run is not None:
        wandb.run.log_code(str(epath.Path(__file__).parent.parent))


def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Loads and validates the weights. Returns a loaded subset of the weights."""
    loaded_params = loader.load(params_shape)
    at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)

    # Remove jax.ShapeDtypeStruct from the loaded params. This makes sure that only the loaded params are returned.
    return traverse_util.unflatten_dict(
        {k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
    )


def apply_lora_to_model(model, config: _config.TrainConfig):
    # LoRA weights are injected by the model config itself; here we only log the detection.
    if "lora" in str(getattr(config.model, "paligemma_variant", "")):
        logging.info(f"Detected LoRA configuration: {config.model.paligemma_variant}")
    return model


def create_diffusion_components(config: _config.TrainConfig, oft_config: OftTrainingConfig, rng: at.KeyArrayLike):
    """Create the diffusion action head and noisy-action projector, if diffusion is enabled."""
    if not oft_config.use_diffusion:
        return None, None

    llm_dim = 2048  # TODO: read the hidden size from the model config instead of hard-coding it.
    action_dim = config.model.action_dim

    diffusion_action_head = DiffusionActionHead(
        input_dim=llm_dim,
        hidden_dim=llm_dim,
        action_dim=action_dim,
        num_diffusion_steps=oft_config.num_diffusion_steps_train,
        rngs=rng,
    )
    noisy_action_projector = NoisyActionProjector(
        input_dim=action_dim,  # Actions are projected per-step, so the input is action_dim.
        llm_dim=llm_dim,
        rngs=rng,
    )
    return diffusion_action_head, noisy_action_projector


def lora_mask(tree):
    """Boolean pytree marking leaves whose path contains 'lora'."""

    def is_lora(path, v):
        return any("lora" in str(p) for p in path)

    return jax.tree_util.tree_map_with_path(is_lora, tree)


@at.typecheck
def init_train_state(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    init_rng: at.KeyArrayLike,
    mesh: jax.sharding.Mesh,
    tx,
    *,
    resume: bool,
) -> tuple[training_utils.TrainState, Any]:
    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        model = config.model.create(model_rng)
        model = apply_lora_to_model(model, config)
        diffusion_action_head, noisy_action_projector = create_diffusion_components(config, oft_config, model_rng)

        if partial_params is not None:
            graphdef, state = nnx.split(model)
            state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, state)

        params = nnx.state(model)
        # Convert frozen params to bfloat16 to save memory.
        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))

        # Use the main optimizer passed in by the caller.
        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=tx,
            opt_state=tx.init(params),
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    train_state_shape = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        return train_state_shape, state_sharding

    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    train_state = jax.jit(
        init,
        donate_argnums=(1,),
        in_shardings=replicated_sharding,
        out_shardings=state_sharding,
    )(init_rng, partial_params)

    return train_state, state_sharding
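
# Illustrative sketch (documentation only) of how `lora_mask` and `optax.masked`
# interact in `main` below: the inner optimizer transforms only leaves whose path
# mentions "lora"; updates for unmasked leaves are passed through unchanged.
def _lora_mask_sketch():
    params = {"lora_a": jnp.ones((2,)), "kernel": jnp.ones((2,))}
    mask = lora_mask(params)
    assert mask == {"lora_a": True, "kernel": False}
    tx = optax.masked(optax.sgd(0.1), mask)
    updates, _ = tx.update(params, tx.init(params), params)  # params double as dummy grads
    assert jnp.allclose(updates["lora_a"], -0.1)  # transformed by sgd
    assert jnp.allclose(updates["kernel"], 1.0)  # passed through as-is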

# TODO: modify the L1 loss in the future.
def compute_l1_loss(predicted_actions: jnp.ndarray, ground_truth_actions: jnp.ndarray) -> jnp.ndarray:
    return jnp.mean(jnp.abs(predicted_actions - ground_truth_actions))


def compute_diffusion_loss(predicted_noise: jnp.ndarray, target_noise: jnp.ndarray) -> jnp.ndarray:
    return jnp.mean((predicted_noise - target_noise) ** 2)


def run_diffusion_sampling(
    model: _model.BaseModel,
    diffusion_action_head: DiffusionActionHead,
    noisy_action_projector: NoisyActionProjector,
    observation: _model.Observation,
    actions: _model.Actions,
    rng: at.KeyArrayLike,
    oft_config: OftTrainingConfig,
) -> jnp.ndarray:
    """Run DDIM sampling through the Pi0FAST backbone and the NoisyActionProjector."""
    if not isinstance(model, Pi0FAST):
        raise ValueError("run_diffusion_sampling only supports a Pi0FAST main model!")

    batch_size = actions.shape[0]
    action_dim = actions.shape[-1]
    action_horizon = actions.shape[1]

    # Start from pure Gaussian noise.
    noise = jax.random.normal(rng, (batch_size, action_horizon, action_dim))

    # Configure the scheduler to use the full training schedule for inference.
    diffusion_action_head.noise_scheduler.set_timesteps(oft_config.num_diffusion_steps_train)
    curr_noisy_actions = noise

    # The observation embedding does not depend on the diffusion step, so compute it once.
    obs_token_emb, input_mask, ar_mask = model.embed_inputs(observation)  # (batch, obs_seq_len, llm_dim)
    obs_seq_len = obs_token_emb.shape[1]

    def diffusion_step(carry, timestep):
        curr_noisy_actions = carry
        timesteps = jnp.full((batch_size,), timestep)

        # Timestep embedding, broadcast across the action horizon.
        diffusion_timestep_embeddings = diffusion_action_head.time_encoder(timesteps)  # (batch, llm_dim)
        diffusion_timestep_embeddings = jnp.expand_dims(diffusion_timestep_embeddings, 1)  # (batch, 1, llm_dim)
        diffusion_timestep_embeddings = jnp.tile(diffusion_timestep_embeddings, (1, action_horizon, 1))  # (batch, action_horizon, llm_dim)

        # Sequence layout: [observation tokens | noisy actions | timestep embeddings].
        noisy_action_emb = noisy_action_projector(curr_noisy_actions)  # (batch, action_horizon, llm_dim)
        full_emb = jnp.concatenate([obs_token_emb, noisy_action_emb, diffusion_timestep_embeddings], axis=1)  # (batch, obs_seq_len + 2 * action_horizon, llm_dim)

        full_input_mask = jnp.concatenate(
            [input_mask, jnp.ones((batch_size, 2 * action_horizon), dtype=input_mask.dtype)], axis=1
        )
        full_ar_mask = jnp.concatenate(
            [ar_mask, jnp.zeros((batch_size, 2 * action_horizon), dtype=ar_mask.dtype)], axis=1
        )
        attn_mask = make_attn_mask(full_input_mask, full_ar_mask)
        attn_mask = attn_mask[:, None, :, :]  # (batch, 1, seq_len, seq_len)

        hidden_states, _, _ = model.PaliGemma.llm(
            embedded_prefix=full_emb,
            mask=attn_mask,
            return_prelogits=True,
        )
        # Read the noise prediction from the noisy-action segment of the sequence.
        actions_hidden_states = hidden_states[:, obs_seq_len : obs_seq_len + action_horizon, :]  # (batch, action_horizon, llm_dim)
        noise_pred = diffusion_action_head.predict_noise(actions_hidden_states)  # (batch, action_horizon, action_dim)
        prev_sample = diffusion_action_head.noise_scheduler.step(noise_pred, timestep, curr_noisy_actions)["prev_sample"]
        return prev_sample, None

    # Denoise from the highest timestep down to 0, so traverse the schedule in
    # reverse order (see the scheduler sketch above).
    final_sample, _ = jax.lax.scan(
        diffusion_step, curr_noisy_actions, jnp.flip(diffusion_action_head.noise_scheduler.timesteps)
    )
    return final_sample
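
# Illustrative sketch (documentation only) of the sequence layout used by
# `run_diffusion_sampling` above and the diffusion loss below:
# [observation tokens | projected noisy actions | timestep embeddings], with the
# noise prediction read back from the noisy-action segment. Sizes are made up.
def _sequence_layout_sketch():
    obs_seq_len, horizon, llm_dim = 10, 5, 32
    obs = jnp.zeros((2, obs_seq_len, llm_dim))
    act = jnp.ones((2, horizon, llm_dim))
    time = 2 * jnp.ones((2, horizon, llm_dim))
    full = jnp.concatenate([obs, act, time], axis=1)  # (2, obs_seq_len + 2 * horizon, llm_dim)
    action_segment = full[:, obs_seq_len : obs_seq_len + horizon, :]
    assert jnp.array_equal(action_segment, act)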

def compute_loss_with_oft_modes(
    model: _model.BaseModel,
    rng: at.KeyArrayLike,
    observation: _model.Observation,
    actions: _model.Actions,
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    diffusion_action_head: Optional[DiffusionActionHead] = None,
    noisy_action_projector: Optional[NoisyActionProjector] = None,
    train: bool = True,
) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Compute the training loss for the selected openvla-oft mode (discrete tokens, L1 regression, or diffusion)."""
    # Use independent keys for the base loss and the diffusion noise sampling.
    loss_rng, diffusion_rng = jax.random.split(rng)
    chunked_loss = model.compute_loss(loss_rng, observation, actions, train=train)
    base_loss = jnp.mean(chunked_loss)
    metrics = {"loss": base_loss}

    # Compute the mode-specific loss.
    if oft_config.use_discrete_tokens:
        # Discrete token prediction (the default next-token objective).
        metrics["discrete_loss"] = base_loss
    elif oft_config.use_l1_regression:
        l1_loss = base_loss  # TODO: compute a true L1 loss over continuous actions.
        metrics["l1_loss"] = l1_loss
        metrics["regression_loss"] = l1_loss
    elif oft_config.use_diffusion and diffusion_action_head is not None:
        batch_size = actions.shape[0]
        action_horizon = actions.shape[1]

        if not isinstance(model, Pi0FAST):
            raise ValueError("The diffusion loss only supports a Pi0FAST main model!")
        if noisy_action_projector is None:
            raise ValueError("The diffusion loss requires a noisy_action_projector; it should not be None.")

        # Corrupt the ground-truth actions with scheduled noise.
        noisy_dict = diffusion_action_head.sample_noisy_actions(actions, diffusion_rng)
        noise = noisy_dict["noise"]  # (batch, action_horizon, action_dim)
        noisy_actions = noisy_dict["noisy_actions"]  # (batch, action_horizon, action_dim)
        diffusion_timestep_embeddings = noisy_dict["diffusion_timestep_embeddings"]  # (batch, llm_dim)

        # Project noisy actions into the LLM embedding space and broadcast the
        # timestep embedding across the action horizon.
        noisy_action_emb = noisy_action_projector(noisy_actions)  # (batch, action_horizon, llm_dim)
        diffusion_timestep_embeddings = jnp.expand_dims(diffusion_timestep_embeddings, 1)  # (batch, 1, llm_dim)
        diffusion_timestep_embeddings = jnp.tile(diffusion_timestep_embeddings, (1, action_horizon, 1))  # (batch, action_horizon, llm_dim)

        # Build the full sequence: [observation tokens | noisy actions | timestep embeddings].
        obs_token_emb, input_mask, ar_mask = model.embed_inputs(observation)  # (batch, obs_seq_len, llm_dim)
        full_emb = jnp.concatenate([obs_token_emb, noisy_action_emb, diffusion_timestep_embeddings], axis=1)  # (batch, obs_seq_len + 2 * action_horizon, llm_dim)
        full_input_mask = jnp.concatenate(
            [input_mask, jnp.ones((batch_size, 2 * action_horizon), dtype=input_mask.dtype)], axis=1
        )
        full_ar_mask = jnp.concatenate(
            [ar_mask, jnp.zeros((batch_size, 2 * action_horizon), dtype=ar_mask.dtype)], axis=1
        )
        attn_mask = make_attn_mask(full_input_mask, full_ar_mask)
        attn_mask = attn_mask[:, None, :, :]  # (batch, 1, seq_len, seq_len)

        hidden_states, _, _ = model.PaliGemma.llm(
            embedded_prefix=full_emb,
            mask=attn_mask,
            return_prelogits=True,
        )
        obs_seq_len = obs_token_emb.shape[1]
        # Read the noise prediction from the noisy-action segment of the sequence.
        actions_hidden_states = hidden_states[:, obs_seq_len : obs_seq_len + action_horizon, :]  # (batch, action_horizon, llm_dim)
        predicted_noise = diffusion_action_head.predict_noise(actions_hidden_states)  # (batch, action_horizon, action_dim)

        # MSE between predicted and true noise.
        diffusion_loss = jnp.mean((predicted_noise - noise) ** 2)
        metrics["diffusion_loss"] = diffusion_loss
        metrics["noise_prediction_loss"] = diffusion_loss
        base_loss = diffusion_loss

    # Tag LoRA runs in the metrics.
    if "lora" in str(getattr(config.model, "paligemma_variant", "")):
        metrics["lora_loss"] = base_loss
        metrics["finetune_mode"] = jnp.array(1.0)  # mark as a finetune run

    return base_loss, metrics
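
# Illustrative sketch (documentation only): `train_step` below keeps an exponential
# moving average of the parameters, ema <- decay * ema + (1 - decay) * new, i.e. a
# low-pass filter with an effective horizon of roughly 1 / (1 - decay) steps.
def _ema_update_sketch():
    decay = 0.9
    ema = 0.0
    for p in [1.0, 1.0, 1.0]:  # three updates toward a constant parameter value
        ema = decay * ema + (1 - decay) * p
    assert abs(ema - (1 - decay**3)) < 1e-9  # converges geometrically toward 1.0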

@at.typecheck
def train_step(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    model = nnx.merge(state.model_def, state.params)
    model.train()

    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    # NOTE: the diffusion head and projector are re-created here on every step; their
    # parameters live outside `state.params` and are not updated by the optimizer.
    diffusion_action_head, noisy_action_projector = create_diffusion_components(config, oft_config, train_rng)

    # openvla-oft loss.
    def loss_fn(m, r, obs, acts):
        return compute_loss_with_oft_modes(
            m, r, obs, acts, config, oft_config, diffusion_action_head, noisy_action_projector, train=True
        )

    # Filter out frozen params when differentiating.
    diff_state = nnx.DiffState(0, config.trainable_filter)
    (loss, metrics), grads = nnx.value_and_grad(loss_fn, argnums=diff_state, has_aux=True)(
        model, train_rng, observation, actions
    )

    params = state.params
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    # Update the model in place and return the new full state.
    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None and state.ema_params is not None:
        ema_decay = state.ema_decay
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: ema_decay * old + (1 - ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Filter out params that aren't kernels.
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )
    info = {
        **metrics,
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }

    # Sample actions for visualization/debugging. NOTE: this runs a full diffusion
    # sampling loop on every training step, which is expensive.
    if diffusion_action_head is not None and noisy_action_projector is not None:
        sampled_actions = run_diffusion_sampling(
            model, diffusion_action_head, noisy_action_projector, observation, actions, rng, oft_config
        )
        # Only keep the first batch element to avoid an oversized info dict.
        info["sampled_actions"] = sampled_actions[:1]

    return new_state, info


def run_validation(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    state: training_utils.TrainState,
    val_data_loader,
    mesh: jax.sharding.Mesh,
    step: int,
) -> Dict[str, float]:
    """Run a small, fixed number of validation batches and average the metrics."""
    if not oft_config.use_val_set:
        return {}

    model = nnx.merge(state.model_def, state.params)
    model.eval()

    val_metrics = []
    val_batches = 0
    for batch in val_data_loader:
        if val_batches >= 10:  # Limit the number of validation batches.
            break
        observation, actions = batch
        # Create diffusion components with a fixed key so validation is deterministic.
        diffusion_action_head, noisy_action_projector = create_diffusion_components(
            config, oft_config, jax.random.key(0)
        )
        loss, metrics = compute_loss_with_oft_modes(
            model, jax.random.key(0), observation, actions, config, oft_config,
            diffusion_action_head, noisy_action_projector, train=False,
        )
        val_metrics.append(metrics)
        val_batches += 1

    # Average the metrics across batches.
    avg_metrics = {}
    if val_metrics:
        for key in val_metrics[0].keys():
            avg_metrics[f"val_{key}"] = jnp.mean(jnp.array([m[key] for m in val_metrics]))
    return avg_metrics


def main(config: _config.TrainConfig):
    init_logging()
    logging.info(f"Running on: {platform.node()}")
    logging.info("Using the openvla-oft enhanced training script")
    logging.info(f"Config: {config.name}")

    # openvla-oft training options.
    oft_config = OftTrainingConfig()

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    mesh = sharding.make_mesh(config.fsdp_devices)
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        str(config.checkpoint_dir),
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, oft_config, resuming=resuming, enabled=config.wandb_enabled)

    data_loader = _data_loader.create_data_loader(
        config,
        sharding=data_sharding,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    # Log images from the first batch as a sanity check.
    images_to_log = [
        wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))
        for i in range(min(5, len(next(iter(batch[0].images.values())))))
    ]
    wandb.log({"camera_views": images_to_log}, step=0)

    # Initialize the model to get the full param tree (only used to build the LoRA mask).
    model = config.model.create(init_rng)
    model = apply_lora_to_model(model, config)
    params = nnx.state(model)
    mask = lora_mask(params)

    # Gradient clipping (clip_norm=1.0) plus a LoRA-masked optimizer. NOTE:
    # optax.masked applies the inner optimizer only where the mask is True; updates
    # for unmasked leaves are passed through unchanged.
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.masked(
            _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None),
            mask,
        ),
    )

    train_state, train_state_sharding = init_train_state(
        config, oft_config, init_rng, mesh, tx=tx, resume=resuming
    )
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    ptrain_step = jax.jit(
        functools.partial(train_step, config, oft_config),
        in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    start_step = int(jax.device_get(train_state.step))
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    gradient_step = 0
    for step in pbar:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(train_rng, train_state, batch)
        infos.append(info)

        # NOTE: parameters are updated on every step above; this counter only paces
        # logging and validation at the gradient-accumulation cadence.
        if (step + 1) % oft_config.grad_accumulation_steps == 0:
            gradient_step += 1
            if gradient_step % config.log_interval == 0:
                stacked_infos = common_utils.stack_forest(infos)
                reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
                info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
                pbar.write(f"Step {step}: {info_str}")
                wandb.log(reduced_info, step=step)
                infos = []

            # Validation. NOTE: this reuses the training loader; a dedicated held-out
            # loader would be needed for a true validation estimate.
            if oft_config.use_val_set and gradient_step % oft_config.val_freq == 0:
                val_metrics = run_validation(config, oft_config, train_state, data_loader, mesh, step)
                if val_metrics:
                    wandb.log(val_metrics, step=step)
                    pbar.write(f"Validation at step {step}: {val_metrics}")

        batch = next(data_iter)

        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()


if __name__ == "__main__":
    main(_config.cli())