# openpi-oft/scripts/train_oft.py
# Uploaded by Sichang0621 with huggingface_hub (commit ce5618e, verified).
import dataclasses
import functools
import logging
import platform
from typing import Any, Optional, Dict, Tuple
import etils.epath as epath
import flax.nnx as nnx
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import optax
import tqdm_loggable.auto as tqdm
import wandb
import numpy as np
import openpi.models.model as _model
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders
from flax.nnx import rnglib
from openpi.models.pi0_fast import Pi0FAST, make_attn_mask
@dataclasses.dataclass
class OftTrainingConfig:
    """openvla-oft training switches and hyper-parameters.

    The three ``use_*`` loss flags are inspected in the order discrete tokens,
    L1 regression, diffusion; the first one that is enabled selects the loss.
    """

    # --- loss-mode flags -------------------------------------------------
    use_l1_regression: bool = False
    use_diffusion: bool = True
    use_discrete_tokens: bool = False

    # --- diffusion schedule ---------------------------------------------
    num_diffusion_steps_train: int = 25
    diffusion_beta_start: float = 1e-4
    # NOTE(review): this end value is smaller than diffusion_beta_start, and neither
    # beta field is read by the diffusion modules in this script -- confirm intent.
    diffusion_beta_end: float = 5e-5

    # Micro-steps per counted "gradient step"; in this script it only drives the
    # logging / validation cadence (an optimizer update still runs every step).
    grad_accumulation_steps: int = 1

    # --- validation -----------------------------------------------------
    use_val_set: bool = False
    val_freq: int = 10000
class DiffusionScheduler:
    """Linear-beta DDPM noise schedule with a deterministic DDIM update step.

    Attributes:
        betas, alphas, alphas_cumprod: per-timestep schedule arrays of length
            ``num_train_timesteps``.
        alphas_cumprod_prev: ``alphas_cumprod`` shifted right by one, with 1.0 at t=0.
        variance: DDPM posterior variance
            ``(1 - abar_{t-1}) / (1 - abar_t) * beta_t`` (0 at t=0); not used by ``step``.
        timesteps: sampling timesteps in DESCENDING order (last entry is 0).
        step_ratio: stride between consecutive sampling timesteps.
    """

    def __init__(self, num_train_timesteps: int, beta_start: float = 0.0001, beta_end: float = 0.02):
        if num_train_timesteps < 1:
            raise ValueError(f"num_train_timesteps must be >= 1, got {num_train_timesteps}")
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas)
        self.alphas_cumprod_prev = jnp.concatenate([jnp.array([1.0]), self.alphas_cumprod[:-1]])
        # Posterior variance of q(x_{t-1} | x_t, x_0); the beta_t factor is part of the formula.
        self.variance = (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) * self.betas
        # Force t=0 to exactly 0 (also avoids 0/0 when beta_start == 0).
        self.variance = self.variance.at[0].set(0.0)
        # By default, sample over every training timestep.
        self.set_timesteps(num_train_timesteps)

    def set_timesteps(self, num_inference_steps: int):
        """Select evenly strided sampling timesteps, ordered from noisiest to t=0.

        Raises:
            ValueError: if ``num_inference_steps`` is not in ``[1, num_train_timesteps]``
                (a zero stride would otherwise be produced).
        """
        if not 1 <= num_inference_steps <= self.num_train_timesteps:
            raise ValueError(
                f"num_inference_steps must be in [1, {self.num_train_timesteps}], got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        self.step_ratio = self.num_train_timesteps // num_inference_steps
        # DDIM denoises from high noise towards t=0, so iterate in descending order.
        self.timesteps = jnp.arange(0, self.num_train_timesteps, self.step_ratio)[::-1]

    def step(self, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray) -> Dict[str, jnp.ndarray]:
        """Deterministic DDIM update from ``timestep`` to ``timestep - step_ratio``.

        ``timestep`` may be a Python int or a traced scalar (e.g. inside ``jax.lax.scan``).

        Returns:
            ``{"prev_sample": x_{t - step_ratio}}``.
        """
        alpha_cumprod = self.alphas_cumprod[timestep]
        # With strided timesteps the previous sampling step is t - step_ratio, not t - 1.
        # Below t=0 the clean-data limit abar = 1 is used.
        prev_timestep = timestep - self.step_ratio
        alpha_cumprod_prev = jnp.where(
            prev_timestep >= 0,
            self.alphas_cumprod[jnp.maximum(prev_timestep, 0)],
            1.0,
        )
        # Predict x_0 from the noise estimate.
        pred_original_sample = (sample - jnp.sqrt(1 - alpha_cumprod) * model_output) / jnp.sqrt(alpha_cumprod)
        # Re-noise the x_0 estimate to the previous timestep (eta = 0).
        pred_sample_direction = jnp.sqrt(1 - alpha_cumprod_prev) * model_output
        prev_sample = jnp.sqrt(alpha_cumprod_prev) * pred_original_sample + pred_sample_direction
        return {"prev_sample": prev_sample}
class TimeEncoder(nnx.Module):
    """Maps integer diffusion timesteps to ``llm_dim``-wide embedding vectors.

    The timestep is cast to float32, lifted to ``llm_dim`` by a 1 -> llm_dim
    linear layer, and refined by a two-layer ReLU MLP.
    """

    def __init__(self, llm_dim: int, rngs: at.KeyArrayLike | None = None):
        super().__init__()
        self.llm_dim = llm_dim
        # Without an explicit key, fall back to a fixed one so construction is deterministic.
        key = jax.random.key(0) if rngs is None else rngs
        param_rngs = rnglib.Rngs(params=key)
        # Keep the creation order unchanged: every layer draws its init keys from param_rngs.
        self.time_embedding = nnx.Linear(1, llm_dim, rngs=param_rngs)
        mlp_in = nnx.Linear(llm_dim, llm_dim, rngs=param_rngs)
        mlp_out = nnx.Linear(llm_dim, llm_dim, rngs=param_rngs)
        self.time_mlp = nnx.Sequential(mlp_in, nnx.relu, mlp_out)

    def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray:
        """Embed ``timesteps`` of shape (batch,) into shape (batch, llm_dim)."""
        as_column = timesteps.astype(jnp.float32)[:, None]
        return self.time_mlp(self.time_embedding(as_column))
class DiffusionActionHead(nnx.Module):
    """Noise-prediction head plus the noise schedule used for diffusion action training.

    Components:
      * ``noise_predictor``: 3-layer ReLU MLP mapping ``input_dim`` features to
        ``action_dim`` noise values.
      * ``time_encoder``: ``TimeEncoder`` producing ``hidden_dim``-wide timestep embeddings.
      * ``noise_scheduler``: ``DiffusionScheduler`` with ``num_diffusion_steps`` training
        timesteps and a linear beta schedule from ``beta_start`` to ``beta_end``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        action_dim: int,
        num_diffusion_steps: int,
        rngs: at.KeyArrayLike | None = None,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.action_dim = action_dim
        self.num_diffusion_steps_train = num_diffusion_steps
        if rngs is None:
            rngs = jax.random.key(0)
        rngs_obj = rnglib.Rngs(params=rngs)
        self.noise_predictor = nnx.Sequential(
            nnx.Linear(input_dim, hidden_dim, rngs=rngs_obj),
            nnx.relu,
            nnx.Linear(hidden_dim, hidden_dim, rngs=rngs_obj),
            nnx.relu,
            nnx.Linear(hidden_dim, action_dim, rngs=rngs_obj),
        )
        # Give the time encoder its own key. Re-using ``rngs`` unchanged would restart the
        # same key stream as noise_predictor, so same-shaped layers could start identical.
        self.time_encoder = TimeEncoder(hidden_dim, rngs=jax.random.fold_in(rngs, 1))
        self.noise_scheduler = DiffusionScheduler(num_diffusion_steps, beta_start=beta_start, beta_end=beta_end)

    def sample_noisy_actions(self, actions: jnp.ndarray, rng: at.KeyArrayLike) -> Dict[str, jnp.ndarray]:
        """Forward-diffuse clean ``actions`` to uniformly sampled training timesteps.

        Args:
            actions: clean actions with the batch on axis 0 (e.g. (batch, horizon, action_dim)).
            rng: PRNG key; split internally so timesteps and noise are independent.

        Returns:
            Dict with ``"noise"`` and ``"noisy_actions"`` (both shaped like ``actions``),
            ``"diffusion_timestep_embeddings"`` (batch, hidden_dim) and ``"timesteps"`` (batch,).
        """
        batch_size = actions.shape[0]
        timestep_rng, noise_rng = jax.random.split(rng)
        timesteps = jax.random.randint(timestep_rng, (batch_size,), 0, self.num_diffusion_steps_train)
        noise = jax.random.normal(noise_rng, actions.shape)
        # Broadcast each example's alpha_bar over all of its non-batch axes.
        alpha_cumprod = self.noise_scheduler.alphas_cumprod[timesteps]
        alpha_cumprod = alpha_cumprod.reshape((batch_size,) + (1,) * (actions.ndim - 1))
        noisy_actions = jnp.sqrt(alpha_cumprod) * actions + jnp.sqrt(1 - alpha_cumprod) * noise
        diffusion_timestep_embeddings = self.time_encoder(timesteps)
        return {
            "noise": noise,
            "noisy_actions": noisy_actions,
            "diffusion_timestep_embeddings": diffusion_timestep_embeddings,
            "timesteps": timesteps,
        }

    def predict_noise(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
        """Predict per-position noise (last axis ``action_dim``) from ``input_dim`` features."""
        return self.noise_predictor(hidden_states)
class NoisyActionProjector(nnx.Module):
    """Single linear layer that lifts noisy action vectors into the LLM embedding width."""

    def __init__(self, input_dim: int, llm_dim: int, rngs: at.KeyArrayLike | None = None):
        super().__init__()
        self.llm_dim = llm_dim
        key = rngs if rngs is not None else jax.random.key(0)
        self.projection = nnx.Linear(input_dim, llm_dim, rngs=rnglib.Rngs(params=key))

    def __call__(self, noisy_actions: jnp.ndarray) -> jnp.ndarray:
        """Project the last axis of ``noisy_actions`` from ``input_dim`` to ``llm_dim``."""
        projected = self.projection(noisy_actions)
        return projected
def init_logging():
    """Install a compact log format (single-letter levels) on the root logger.

    Sets the root level to INFO. If the root logger has no handler yet, a
    ``StreamHandler`` is added first instead of failing with ``IndexError``.
    The formatter is applied to the first root handler.
    """
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            # Work on a copy: the same LogRecord is passed to every handler, so
            # rewriting levelname in place would leak the abbreviation to them.
            record = logging.makeLogRecord(record.__dict__)
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.handlers[0].setFormatter(formatter)
def init_wandb(config: _config.TrainConfig, oft_config: OftTrainingConfig, *, resuming: bool, log_code: bool = False, enabled: bool = True):
    """Start, or re-attach to, the wandb run for this training job.

    When ``enabled`` is False a disabled wandb run is started and nothing else happens.
    On a fresh run the run name is ``<exp_name>+oft`` plus suffixes for LoRA,
    disabled EMA and the active OFT loss modes. The generated run id is written to
    ``<checkpoint_dir>/wandb_id.txt`` so a resumed job can re-attach with
    ``resume="must"``.

    Raises:
        FileNotFoundError: if ``config.checkpoint_dir`` does not exist.
    """
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        run_name = f"{config.exp_name}+oft"
        # LoRA variants are recognised by name; models without ``paligemma_variant``
        # simply get no "+lora" suffix (no exception handling needed).
        if "lora" in str(getattr(config.model, "paligemma_variant", "")):
            run_name += "+lora"
        if config.ema_decay is None:
            run_name += "+no_ema"
        if oft_config.use_l1_regression:
            run_name += "+l1_regression"
        if oft_config.use_diffusion:
            run_name += "+diffusion"
        if oft_config.use_discrete_tokens:
            run_name += "+discrete"
        wandb.init(
            name=run_name,
            config={
                **dataclasses.asdict(config),
                **dataclasses.asdict(oft_config),
            },
            project=config.project_name,
        )
        if wandb.run is not None:
            (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code and wandb.run is not None:
        wandb.run.log_code(str(epath.Path(__file__).parent.parent))
def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Load weights with ``loader``, validate them, and return only the leaves actually loaded.

    ``at.check_pytree_equality`` compares the loaded tree against ``params_shape``
    (structure, shapes and dtypes). Any leaf that is still a ``jax.ShapeDtypeStruct``
    placeholder is dropped, so the result holds just the subset of weights that were loaded.
    """
    loaded = loader.load(params_shape)
    at.check_pytree_equality(expected=params_shape, got=loaded, check_shapes=True, check_dtypes=True)

    materialized = {}
    for path, leaf in traverse_util.flatten_dict(loaded).items():
        if isinstance(leaf, jax.ShapeDtypeStruct):
            continue  # placeholder: this weight was not provided by the loader
        materialized[path] = leaf
    return traverse_util.unflatten_dict(materialized)
def apply_lora_to_model(model, config: _config.TrainConfig):
    """Return ``model`` unchanged, logging when the config names a LoRA PaliGemma variant.

    No adapters are injected here. Presumably any LoRA layers are created by
    ``config.model.create`` when the variant requests them -- TODO confirm.
    Configs without ``paligemma_variant`` are treated as non-LoRA.
    """
    if "lora" in str(getattr(config.model, "paligemma_variant", "")):
        logging.info(f"Detected LoRA configuration: {config.model.paligemma_variant}")
    return model
def create_diffusion_components(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    rng: at.KeyArrayLike,
    llm_dim: int = 2048,
):
    """Build ``(DiffusionActionHead, NoisyActionProjector)``, or ``(None, None)`` if diffusion is off.

    Args:
        config: training config; ``config.model.action_dim`` sets the action width.
        oft_config: nothing is built unless ``use_diffusion`` is set;
            ``num_diffusion_steps_train`` sets the number of training timesteps.
        rng: PRNG key, split so the head and the projector get independent initializations.
        llm_dim: embedding width of the main model's LLM (default 2048). The projected
            action tokens are concatenated with the observation token embeddings, so this
            must match their last dimension.

    NOTE(review): callers in this script build fresh components on every call, so these
    weights are not part of the train state and are never optimized -- confirm intent.
    """
    if not oft_config.use_diffusion:
        return None, None

    action_dim = config.model.action_dim
    head_rng, projector_rng = jax.random.split(rng)

    diffusion_action_head = DiffusionActionHead(
        input_dim=llm_dim,
        hidden_dim=llm_dim,
        action_dim=action_dim,
        num_diffusion_steps=oft_config.num_diffusion_steps_train,
        rngs=head_rng,
    )
    # The projector only sees raw noisy action vectors (action_dim wide).
    noisy_action_projector = NoisyActionProjector(
        input_dim=action_dim,
        llm_dim=llm_dim,
        rngs=projector_rng,
    )
    return diffusion_action_head, noisy_action_projector
def lora_mask(tree):
    """Return a pytree of bools shaped like ``tree``.

    A leaf is True when any key on its path contains the substring "lora".
    """

    def _on_lora_path(key_path, _leaf):
        return any("lora" in str(key) for key in key_path)

    return jax.tree_util.tree_map_with_path(_on_lora_path, tree)
@at.typecheck
def init_train_state(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    init_rng: at.KeyArrayLike,
    mesh: jax.sharding.Mesh,
    tx,
    *,
    resume: bool
) -> tuple[training_utils.TrainState, Any]:
    """Create the FSDP-sharded initial ``TrainState``, or only its abstract shape when resuming.

    Args:
        config: training config (model, freeze filter, EMA decay, weight loader).
        oft_config: currently unused here; kept for interface compatibility.
        init_rng: key used for model initialization.
        mesh: device mesh used to derive the FSDP sharding.
        tx: optax transformation whose state is initialized from the params.
        resume: if True, return the ``jax.eval_shape`` result without loading weights.
            The caller is expected to restore real values into it.

    Returns:
        ``(train_state, state_sharding)``.
    """

    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        _, model_rng = jax.random.split(rng)
        model = config.model.create(model_rng)
        model = apply_lora_to_model(model, config)
        # The diffusion head / projector are not part of the train state, so they are
        # not built here.
        if partial_params is not None:
            # Overwrite the freshly initialized values with the loaded weights.
            graphdef, state = nnx.split(model)
            state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, state)
        params = nnx.state(model)
        # Params matched by config.freeze_filter are stored in bfloat16.
        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))
        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=tx,
            opt_state=tx.init(params),
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    train_state_shape = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)

    if resume:
        return train_state_shape, state_sharding

    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    # Initialize directly into the target sharding; the loaded weights buffer is donated.
    train_state = jax.jit(
        init,
        donate_argnums=(1,),
        in_shardings=replicated_sharding,
        out_shardings=state_sharding,
    )(init_rng, partial_params)

    return train_state, state_sharding
# TODO: revisit this L1 loss; it is not yet used by the training loss.
def compute_l1_loss(predicted_actions: jnp.ndarray, ground_truth_actions: jnp.ndarray) -> jnp.ndarray:
    """Mean absolute error between predicted and ground-truth actions, averaged over all elements."""
    abs_error = jnp.abs(predicted_actions - ground_truth_actions)
    return abs_error.mean()
def compute_diffusion_loss(predicted_noise: jnp.ndarray, target_noise: jnp.ndarray) -> jnp.ndarray:
    """Mean squared error between predicted and target noise, averaged over all elements."""
    residual = predicted_noise - target_noise
    return jnp.mean(jnp.square(residual))
def run_diffusion_sampling(
    model: _model.BaseModel,
    diffusion_action_head: DiffusionActionHead,
    noisy_action_projector: NoisyActionProjector,
    observation: _model.Observation,
    actions: _model.Actions,
    rng: at.KeyArrayLike,
    oft_config: OftTrainingConfig,
) -> jnp.ndarray:
    """Sample actions by iteratively denoising Gaussian noise with the Pi0FAST LLM in the loop.

    Each denoising step runs ``model.PaliGemma.llm`` on the token sequence
    ``[observation tokens | projected noisy actions | timestep embeddings]``. It then
    predicts noise from the hidden states at the noisy-action positions and applies
    ``noise_scheduler.step``.

    Args:
        model: must be a ``Pi0FAST``.
        diffusion_action_head: provides the time encoder, noise predictor and scheduler.
        noisy_action_projector: lifts noisy actions to the LLM embedding width.
        observation: conditioning observation.
        actions: used only for its shape (batch, action_horizon, action_dim).
        rng: key for the initial noise.
        oft_config: ``num_diffusion_steps_train`` sets the number of sampling steps.

    Returns:
        Denoised actions shaped (batch, action_horizon, action_dim).

    Raises:
        ValueError: if ``model`` is not a ``Pi0FAST``.
    """
    # Validate once up front rather than inside every scan step.
    if not isinstance(model, Pi0FAST):
        raise ValueError("run_diffusion_sampling only supports Pi0FAST main model!")

    batch_size = actions.shape[0]
    action_horizon = actions.shape[1]
    action_dim = actions.shape[-1]

    scheduler = diffusion_action_head.noise_scheduler
    scheduler.set_timesteps(oft_config.num_diffusion_steps_train)
    # Denoise from the noisiest timestep down to 0, whatever order the scheduler stores them in.
    sampling_timesteps = jnp.sort(scheduler.timesteps)[::-1]

    # The observation prefix and the attention mask are identical for every denoising
    # step, so compute them once outside the scan.
    obs_token_emb, input_mask, ar_mask = model.embed_inputs(observation)  # (batch, obs_seq_len, llm_dim)
    obs_seq_len = obs_token_emb.shape[1]
    num_extra_tokens = 2 * action_horizon  # noisy-action tokens followed by timestep tokens
    full_input_mask = jnp.concatenate(
        [input_mask, jnp.ones((batch_size, num_extra_tokens), dtype=input_mask.dtype)], axis=1
    )
    full_ar_mask = jnp.concatenate(
        [ar_mask, jnp.zeros((batch_size, num_extra_tokens), dtype=ar_mask.dtype)], axis=1
    )
    attn_mask = make_attn_mask(full_input_mask, full_ar_mask)[:, None, :, :]  # (batch, 1, seq, seq)

    def diffusion_step(curr_noisy_actions, timestep):
        timesteps = jnp.full((batch_size,), timestep)
        time_emb = diffusion_action_head.time_encoder(timesteps)  # (batch, llm_dim)
        time_emb = jnp.tile(time_emb[:, None, :], (1, action_horizon, 1))  # (batch, horizon, llm_dim)
        noisy_action_emb = noisy_action_projector(curr_noisy_actions)  # (batch, horizon, llm_dim)
        full_emb = jnp.concatenate([obs_token_emb, noisy_action_emb, time_emb], axis=1)

        hidden_states, _, _ = model.PaliGemma.llm(
            embedded_prefix=full_emb,
            mask=attn_mask,
            return_prelogits=True,
        )
        actions_hidden_states = hidden_states[:, obs_seq_len:obs_seq_len + action_horizon, :]
        noise_pred = diffusion_action_head.predict_noise(actions_hidden_states)  # (batch, horizon, action_dim)
        prev_sample = scheduler.step(noise_pred, timestep, curr_noisy_actions)["prev_sample"]
        return prev_sample, None

    noise = jax.random.normal(rng, (batch_size, action_horizon, action_dim))
    final_sample, _ = jax.lax.scan(diffusion_step, noise, sampling_timesteps)
    return final_sample
def compute_loss_with_oft_modes(
    model: _model.BaseModel,
    rng: at.KeyArrayLike,
    observation: _model.Observation,
    actions: _model.Actions,
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    diffusion_action_head: Optional[DiffusionActionHead] = None,
    noisy_action_projector: Optional[NoisyActionProjector] = None,
    train: bool = True
) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Compute the scalar training loss for the active openvla-oft mode.

    The model's own ``compute_loss`` always runs; its mean is reported as ``"base_loss"``.
    The first enabled mode, in this order, then decides the returned loss:

    * ``use_discrete_tokens``: the model loss (also reported as ``"discrete_loss"``).
    * ``use_l1_regression``: placeholder that currently returns the model loss (TODO).
    * ``use_diffusion`` with a head: MSE between sampled noise and the noise predicted
      from the LLM hidden states at the noisy-action positions.

    If no mode applies, the model loss is returned. ``metrics["loss"]`` always equals
    the returned loss.

    Raises:
        ValueError: in diffusion mode, if ``model`` is not ``Pi0FAST`` or
            ``noisy_action_projector`` is None.
    """
    # Independent keys for the model loss and for the diffusion noise/timestep sampling.
    loss_rng, diffusion_rng = jax.random.split(rng)

    chunked_loss = model.compute_loss(loss_rng, observation, actions, train=train)
    base_loss = jnp.mean(chunked_loss)
    metrics = {"base_loss": base_loss}
    loss = base_loss

    if oft_config.use_discrete_tokens:
        metrics["discrete_loss"] = base_loss
    elif oft_config.use_l1_regression:
        # TODO: real L1 regression loss; the model loss is a placeholder for now.
        metrics["l1_loss"] = base_loss
        metrics["regression_loss"] = base_loss
    elif oft_config.use_diffusion and diffusion_action_head is not None:
        if not isinstance(model, Pi0FAST):
            raise ValueError("diffusion loss only supports Pi0FAST main model!")
        if noisy_action_projector is None:
            raise ValueError("diffusion loss needs noisy_action_projector, should not be None")

        batch_size = actions.shape[0]
        action_horizon = actions.shape[1]

        noisy_dict = diffusion_action_head.sample_noisy_actions(actions, diffusion_rng)
        noise = noisy_dict["noise"]  # (batch, horizon, action_dim)
        noisy_action_emb = noisy_action_projector(noisy_dict["noisy_actions"])  # (batch, horizon, llm_dim)
        # One timestep embedding per example, repeated for every action position.
        time_emb = noisy_dict["diffusion_timestep_embeddings"][:, None, :]
        time_emb = jnp.tile(time_emb, (1, action_horizon, 1))  # (batch, horizon, llm_dim)

        obs_token_emb, input_mask, ar_mask = model.embed_inputs(observation)  # (batch, obs_seq_len, llm_dim)
        full_emb = jnp.concatenate([obs_token_emb, noisy_action_emb, time_emb], axis=1)
        num_extra_tokens = 2 * action_horizon
        full_input_mask = jnp.concatenate(
            [input_mask, jnp.ones((batch_size, num_extra_tokens), dtype=input_mask.dtype)], axis=1
        )
        full_ar_mask = jnp.concatenate(
            [ar_mask, jnp.zeros((batch_size, num_extra_tokens), dtype=ar_mask.dtype)], axis=1
        )
        attn_mask = make_attn_mask(full_input_mask, full_ar_mask)[:, None, :, :]  # (batch, 1, seq, seq)

        hidden_states, _, _ = model.PaliGemma.llm(
            embedded_prefix=full_emb,
            mask=attn_mask,
            return_prelogits=True,
        )
        obs_seq_len = obs_token_emb.shape[1]
        actions_hidden_states = hidden_states[:, obs_seq_len:obs_seq_len + action_horizon, :]
        predicted_noise = diffusion_action_head.predict_noise(actions_hidden_states)

        diffusion_loss = compute_diffusion_loss(predicted_noise, noise)
        metrics["diffusion_loss"] = diffusion_loss
        metrics["noise_prediction_loss"] = diffusion_loss
        loss = diffusion_loss

    # LoRA runs are recognised by the variant name; configs without one are not LoRA.
    if "lora" in str(getattr(config.model, "paligemma_variant", "")):
        metrics["lora_loss"] = loss
        metrics["finetune_mode"] = jnp.array(1.0)  # mark as finetune mode

    metrics["loss"] = loss
    return loss, metrics
@at.typecheck
def train_step(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    """Run one optimizer step and return the new state plus logging info.

    The loss and its gradients come from a single forward/backward pass. Then
    ``state.tx`` is applied and the EMA params are updated (if enabled).

    Returns:
        ``(new_state, info)``. ``info`` holds the loss metrics, ``grad_norm`` and
        ``param_norm``, and, when diffusion components exist, ``sampled_actions`` for
        the first batch element.
    """
    model = nnx.merge(state.model_def, state.params)
    model.train()

    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    # NOTE(review): the diffusion head/projector are rebuilt from train_rng on every step
    # and are not stored in ``state``, so their weights are random per step and never
    # trained. Keeping them requires changes to init_train_state and the optimizer too.
    diffusion_action_head, noisy_action_projector = create_diffusion_components(config, oft_config, train_rng)

    def loss_fn(m, r, obs, acts):
        return compute_loss_with_oft_modes(
            m, r, obs, acts, config, oft_config, diffusion_action_head, noisy_action_projector, train=True
        )

    # One forward/backward pass; gradients are taken w.r.t. the model's nnx.Param state.
    # NOTE(review): config.trainable_filter is not applied to the gradients here; which
    # params actually change is decided by ``state.tx``.
    (_, metrics), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, train_rng, observation, actions)

    params = state.params
    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None and state.ema_params is not None:
        ema_decay = state.ema_decay
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: ema_decay * old + (1 - ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Only multi-dimensional kernels count towards param_norm.
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )
    info = {
        **metrics,
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }

    if diffusion_action_head is not None and noisy_action_projector is not None:
        # Debug sampling: a full denoising loop (one LLM pass per diffusion step) on every
        # train step. Use a per-step key so the sampled noise is not identical across steps.
        sample_rng = jax.random.fold_in(train_rng, 1)
        sampled_actions = run_diffusion_sampling(
            model, diffusion_action_head, noisy_action_projector, observation, actions, sample_rng, oft_config
        )
        # Keep only the first batch element to keep info small.
        info["sampled_actions"] = sampled_actions[:1]

    return new_state, info
def run_validation(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    state: training_utils.TrainState,
    val_data_loader,
    mesh: jax.sharding.Mesh,
    step: int,
    *,
    max_batches: int = 10,
) -> Dict[str, float]:
    """Average the OFT loss metrics over at most ``max_batches`` batches of ``val_data_loader``.

    Returns an empty dict when ``oft_config.use_val_set`` is False or the loader yields
    nothing. Otherwise each metric key from ``compute_loss_with_oft_modes`` is returned
    with a ``val_`` prefix. ``mesh`` and ``step`` are accepted for interface
    compatibility and are not used.
    """
    import itertools

    if not oft_config.use_val_set:
        return {}

    model = nnx.merge(state.model_def, state.params)
    model.eval()

    # A fixed key keeps validation deterministic. The components only depend on that
    # key, so build them once rather than once per batch.
    eval_rng = jax.random.key(0)
    diffusion_action_head, noisy_action_projector = create_diffusion_components(config, oft_config, eval_rng)

    val_metrics = []
    # islice stops after max_batches without fetching an extra batch from the loader.
    for observation, actions in itertools.islice(val_data_loader, max_batches):
        _, metrics = compute_loss_with_oft_modes(
            model, eval_rng, observation, actions, config, oft_config,
            diffusion_action_head, noisy_action_projector, train=False
        )
        val_metrics.append(metrics)

    if not val_metrics:
        return {}
    return {
        f"val_{key}": jnp.mean(jnp.array([m[key] for m in val_metrics]))
        for key in val_metrics[0]
    }
def _log_camera_views(batch) -> None:
    """Log up to five examples of ``batch`` to wandb at step 0, with all camera views side by side."""
    images = batch[0].images
    num_examples = min(5, len(next(iter(images.values()))))
    images_to_log = [
        wandb.Image(np.concatenate([np.array(img[i]) for img in images.values()], axis=1))
        for i in range(num_examples)
    ]
    wandb.log({"camera_views": images_to_log}, step=0)


def _make_lora_optimizer(config: _config.TrainConfig, rng: at.KeyArrayLike) -> optax.GradientTransformation:
    """Build the optimizer: global-norm clipping, then ``config.optimizer`` on LoRA params only.

    Params whose path contains "lora" (per ``lora_mask``) get the configured optimizer
    update. Every other param gets a zero update and stays frozen.
    NOTE: the chain layout defines the optimizer-state structure stored in checkpoints.
    """
    # Only parameter paths are needed for the mask, so trace the model abstractly
    # instead of allocating a full set of real weights.
    abstract_model = nnx.eval_shape(lambda: apply_lora_to_model(config.model.create(rng), config))
    mask = lora_mask(nnx.state(abstract_model))
    if not any(jax.tree.leaves(mask)):
        logging.warning("No LoRA parameters found; the optimizer will not update any parameter.")
    frozen_mask = jax.tree.map(lambda is_lora: not is_lora, mask)
    base_optimizer = _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None)
    return optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.masked(base_optimizer, mask),
        # optax.masked passes updates of masked-out leaves through unchanged. Without
        # this, the raw (clipped) gradient would be added to every non-LoRA parameter.
        optax.masked(optax.set_to_zero(), frozen_mask),
    )


def main(config: _config.TrainConfig):
    """Set up data, model, optimizer and checkpointing, then run the openvla-oft training loop."""
    init_logging()
    logging.info(f"Running on: {platform.node()}")
    logging.info("Using openvla-oft enhanced training script")
    logging.info(f"Config: {config.name}")

    oft_config = OftTrainingConfig()

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    mesh = sharding.make_mesh(config.fsdp_devices)
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        str(config.checkpoint_dir),
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, oft_config, resuming=resuming, enabled=config.wandb_enabled)

    data_loader = _data_loader.create_data_loader(
        config,
        sharding=data_sharding,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")
    # Log images from the first batch as a sanity check.
    _log_camera_views(batch)

    tx = _make_lora_optimizer(config, init_rng)
    train_state, train_state_sharding = init_train_state(
        config, oft_config, init_rng, mesh, tx=tx, resume=resuming
    )
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    ptrain_step = jax.jit(
        functools.partial(train_step, config, oft_config),
        in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    start_step = int(jax.device_get(train_state.step))
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    gradient_step = 0
    for step in pbar:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(train_rng, train_state, batch)
        infos.append(info)

        # A "gradient step" completes every grad_accumulation_steps micro-steps. Logging and
        # validation only fire when one completes, so they do not repeat on intermediate
        # micro-steps. (train_step still applies an optimizer update on every micro-step.)
        if (step + 1) % oft_config.grad_accumulation_steps == 0:
            gradient_step += 1

            if gradient_step % config.log_interval == 0:
                stacked_infos = common_utils.stack_forest(infos)
                reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
                info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
                pbar.write(f"Step {step}: {info_str}")
                wandb.log(reduced_info, step=step)
                infos = []

            if oft_config.use_val_set and gradient_step % oft_config.val_freq == 0:
                # NOTE(review): validation currently reads from the training data loader.
                val_metrics = run_validation(config, oft_config, train_state, data_loader, mesh, step)
                if val_metrics:
                    wandb.log(val_metrics, step=step)
                    pbar.write(f"Validation at step {step}: {val_metrics}")

        batch = next(data_iter)

        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()
if __name__ == "__main__":
    # Parse the TrainConfig from the command line, then launch training.
    cli_config = _config.cli()
    main(cli_config)