import dataclasses
import functools
import logging
import platform
from typing import Any, Dict, Optional, Tuple

import etils.epath as epath
import flax.nnx as nnx
from flax.nnx import rnglib
from flax.training import common_utils
import flax.traverse_util as traverse_util
import jax
import jax.experimental
import jax.numpy as jnp
import numpy as np
import optax
import tqdm_loggable.auto as tqdm
import wandb

import openpi.models.model as _model
from openpi.models.pi0_fast import Pi0FAST, make_attn_mask
import openpi.shared.array_typing as at
import openpi.shared.nnx_utils as nnx_utils
import openpi.training.checkpoints as _checkpoints
import openpi.training.config as _config
import openpi.training.data_loader as _data_loader
import openpi.training.optimizer as _optimizer
import openpi.training.sharding as sharding
import openpi.training.utils as training_utils
import openpi.training.weight_loaders as _weight_loaders


@dataclasses.dataclass
class OftTrainingConfig:
    """Training-mode options adapted from openvla-oft."""

    use_l1_regression: bool = False
    use_diffusion: bool = True
    use_discrete_tokens: bool = False

    # Diffusion noise schedule; passed through to DiffusionScheduler.
    num_diffusion_steps_train: int = 25
    diffusion_beta_start: float = 0.0001
    diffusion_beta_end: float = 0.00005

    # Only affects logging/validation cadence in main(); see the train loop.
    grad_accumulation_steps: int = 1

    use_val_set: bool = False
    val_freq: int = 10_000
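
# Note: the three use_* flags above mirror openvla-oft's action-head modes and
# are consumed by compute_loss_with_oft_modes(), which checks them in
# discrete -> l1 -> diffusion order, so the first enabled mode wins.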


class DiffusionScheduler:
    """Minimal DDPM-style noise scheduler with a deterministic (DDIM-like) reverse step."""

    def __init__(self, num_train_timesteps: int, beta_start: float = 0.0001, beta_end: float = 0.02):
        self.num_train_timesteps = num_train_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end

        # Linear beta schedule and the cumulative alpha products used by both
        # forward noising and the reverse step.
        self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = jnp.cumprod(self.alphas)
        self.alphas_cumprod_prev = jnp.concatenate([jnp.array([1.0]), self.alphas_cumprod[:-1]])

        # DDPM posterior variance, beta_t * (1 - abar_{t-1}) / (1 - abar_t);
        # unused by the deterministic step below, kept for reference.
        self.variance = self.betas * (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
        self.variance = jnp.concatenate([jnp.array([0.0]), self.variance[1:]])

        self.timesteps = jnp.arange(0, num_train_timesteps)

    def set_timesteps(self, num_inference_steps: int):
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_train_timesteps // num_inference_steps
        # Sampling iterates from high noise to low noise, so store the strided
        # timesteps in descending order.
        self.timesteps = jnp.arange(0, self.num_train_timesteps, step_ratio)[::-1]

    def step(self, model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray) -> Dict[str, jnp.ndarray]:
        alpha_cumprod = self.alphas_cumprod[timestep]
        alpha_cumprod_prev = self.alphas_cumprod_prev[timestep]

        # Recover x_0 from the noise prediction, then take a deterministic DDIM step:
        #   x_{t-1} = sqrt(abar_{t-1}) * x0_hat + sqrt(1 - abar_{t-1}) * eps_hat.
        pred_original_sample = (sample - jnp.sqrt(1 - alpha_cumprod) * model_output) / jnp.sqrt(alpha_cumprod)
        pred_sample_direction = jnp.sqrt(1 - alpha_cumprod_prev) * model_output
        prev_sample = jnp.sqrt(alpha_cumprod_prev) * pred_original_sample + pred_sample_direction

        return {"prev_sample": prev_sample}


class TimeEncoder(nnx.Module):
    """Embeds scalar diffusion timesteps into the LLM embedding space."""

    def __init__(self, llm_dim: int, rngs: at.KeyArrayLike | None = None):
        super().__init__()
        self.llm_dim = llm_dim
        if rngs is None:
            rngs = jax.random.key(0)
        rngs_obj = rnglib.Rngs(params=rngs)
        self.time_embedding = nnx.Linear(1, llm_dim, rngs=rngs_obj)
        self.time_mlp = nnx.Sequential(
            nnx.Linear(llm_dim, llm_dim, rngs=rngs_obj),
            nnx.relu,
            nnx.Linear(llm_dim, llm_dim, rngs=rngs_obj),
        )

    def __call__(self, timesteps: jnp.ndarray) -> jnp.ndarray:
        # Treat the integer timestep as a single float feature:
        # (B,) -> (B, 1) -> (B, llm_dim).
        timesteps = timesteps.astype(jnp.float32)
        time_emb = self.time_embedding(timesteps[:, None])
        time_emb = self.time_mlp(time_emb)
        return time_emb
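
# Shape sketch, assuming llm_dim=2048:
#   enc = TimeEncoder(2048, rngs=jax.random.key(0))
#   jax.eval_shape(enc, jnp.zeros((4,), jnp.int32)).shape  # -> (4, 2048)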


class DiffusionActionHead(nnx.Module):
    """Predicts the injected noise from LLM hidden states (openvla-oft style)."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        action_dim: int,
        num_diffusion_steps: int,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        rngs: at.KeyArrayLike | None = None,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.action_dim = action_dim
        self.num_diffusion_steps_train = num_diffusion_steps

        if rngs is None:
            rngs = jax.random.key(0)
        rngs_obj = rnglib.Rngs(params=rngs)

        # MLP that maps per-action-token hidden states to a noise prediction.
        self.noise_predictor = nnx.Sequential(
            nnx.Linear(input_dim, hidden_dim, rngs=rngs_obj),
            nnx.relu,
            nnx.Linear(hidden_dim, hidden_dim, rngs=rngs_obj),
            nnx.relu,
            nnx.Linear(hidden_dim, action_dim, rngs=rngs_obj),
        )

        self.time_encoder = TimeEncoder(hidden_dim, rngs=rngs)
        self.noise_scheduler = DiffusionScheduler(num_diffusion_steps, beta_start=beta_start, beta_end=beta_end)

    def sample_noisy_actions(self, actions: jnp.ndarray, rng: at.KeyArrayLike) -> Dict[str, jnp.ndarray]:
        batch_size = actions.shape[0]

        # Use independent keys for the timestep draw and the noise draw.
        t_rng, noise_rng = jax.random.split(rng)
        timesteps = jax.random.randint(t_rng, (batch_size,), 0, self.num_diffusion_steps_train)
        noise = jax.random.normal(noise_rng, actions.shape)

        # Closed-form forward noising: x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps.
        alpha_cumprod = self.noise_scheduler.alphas_cumprod[timesteps]
        alpha_cumprod = alpha_cumprod.reshape(-1, 1, 1)
        noisy_actions = jnp.sqrt(alpha_cumprod) * actions + jnp.sqrt(1 - alpha_cumprod) * noise

        diffusion_timestep_embeddings = self.time_encoder(timesteps)

        return {
            "noise": noise,
            "noisy_actions": noisy_actions,
            "diffusion_timestep_embeddings": diffusion_timestep_embeddings,
            "timesteps": timesteps,
        }

    def predict_noise(self, hidden_states: jnp.ndarray) -> jnp.ndarray:
        return self.noise_predictor(hidden_states)
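
# Sketch of the training-side call, with assumed toy shapes:
#   head = DiffusionActionHead(input_dim=2048, hidden_dim=2048, action_dim=7, num_diffusion_steps=25)
#   out = head.sample_noisy_actions(jnp.zeros((2, 8, 7)), jax.random.key(0))
#   out["noisy_actions"].shape                   # (2, 8, 7)
#   out["diffusion_timestep_embeddings"].shape   # (2, 2048)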


class NoisyActionProjector(nnx.Module):
    """Projects noisy actions from action space into the LLM embedding space."""

    def __init__(self, input_dim: int, llm_dim: int, rngs: at.KeyArrayLike | None = None):
        super().__init__()
        self.llm_dim = llm_dim
        if rngs is None:
            rngs = jax.random.key(0)
        rngs_obj = rnglib.Rngs(params=rngs)
        self.projection = nnx.Linear(input_dim, llm_dim, rngs=rngs_obj)

    def __call__(self, noisy_actions: jnp.ndarray) -> jnp.ndarray:
        return self.projection(noisy_actions)


def init_logging():
    """Custom logging format for better readability."""
    level_mapping = {"DEBUG": "D", "INFO": "I", "WARNING": "W", "ERROR": "E", "CRITICAL": "C"}

    class CustomFormatter(logging.Formatter):
        def format(self, record):
            record.levelname = level_mapping.get(record.levelname, record.levelname)
            return super().format(record)

    formatter = CustomFormatter(
        fmt="%(asctime)s.%(msecs)03d [%(levelname)s] %(message)-80s (%(process)d:%(filename)s:%(lineno)s)",
        datefmt="%H:%M:%S",
    )

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        logger.addHandler(logging.StreamHandler())
    logger.handlers[0].setFormatter(formatter)


def init_wandb(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    *,
    resuming: bool,
    log_code: bool = False,
    enabled: bool = True,
):
    if not enabled:
        wandb.init(mode="disabled")
        return

    ckpt_dir = config.checkpoint_dir
    if not ckpt_dir.exists():
        raise FileNotFoundError(f"Checkpoint directory {ckpt_dir} does not exist.")

    if resuming:
        run_id = (ckpt_dir / "wandb_id.txt").read_text().strip()
        wandb.init(id=run_id, resume="must", project=config.project_name)
    else:
        # Build a descriptive run name from the experiment name and the enabled modes.
        run_name = f"{config.exp_name}+oft"
        if "lora" in str(getattr(config.model, "paligemma_variant", "")):
            run_name += "+lora"
        if config.ema_decay is None:
            run_name += "+no_ema"

        if oft_config.use_l1_regression:
            run_name += "+l1_regression"
        if oft_config.use_diffusion:
            run_name += "+diffusion"
        if oft_config.use_discrete_tokens:
            run_name += "+discrete"

        wandb.init(
            name=run_name,
            config={**dataclasses.asdict(config), **dataclasses.asdict(oft_config)},
            project=config.project_name,
        )
        if wandb.run is not None:
            (ckpt_dir / "wandb_id.txt").write_text(wandb.run.id)

    if log_code and wandb.run is not None:
        wandb.run.log_code(str(epath.Path(__file__).parent.parent))


def _load_weights_and_validate(loader: _weight_loaders.WeightLoader, params_shape: at.Params) -> at.Params:
    """Loads and validates the weights. Returns a loaded subset of the weights."""
    loaded_params = loader.load(params_shape)
    at.check_pytree_equality(expected=params_shape, got=loaded_params, check_shapes=True, check_dtypes=True)

    return traverse_util.unflatten_dict(
        {k: v for k, v in traverse_util.flatten_dict(loaded_params).items() if not isinstance(v, jax.ShapeDtypeStruct)}
    )
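
# Leaves the loader did not fill remain jax.ShapeDtypeStruct placeholders (the
# shapes passed in), so the flatten/filter/unflatten round trip above drops
# them; only concretely loaded arrays are later merged into the freshly
# initialized params inside init_train_state().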


def apply_lora_to_model(model, config: _config.TrainConfig):
    """Currently a passthrough: LoRA layers are expected to come from the model
    config itself, so this hook only logs when a LoRA variant is detected."""
    if "lora" in str(getattr(config.model, "paligemma_variant", "")):
        logging.info(f"Detected LoRA configuration: {config.model.paligemma_variant}")
    return model


def create_diffusion_components(config: _config.TrainConfig, oft_config: OftTrainingConfig, rng: at.KeyArrayLike):
    """Creates the diffusion action head and noisy-action projector, or (None, None)."""
    if not oft_config.use_diffusion:
        return None, None

    # Embedding width of the PaliGemma LLM backbone; assumed fixed here.
    llm_dim = 2048
    action_dim = config.model.action_dim

    diffusion_action_head = DiffusionActionHead(
        input_dim=llm_dim,
        hidden_dim=llm_dim,
        action_dim=action_dim,
        num_diffusion_steps=oft_config.num_diffusion_steps_train,
        beta_start=oft_config.diffusion_beta_start,
        beta_end=oft_config.diffusion_beta_end,
        rngs=rng,
    )

    noisy_action_projector = NoisyActionProjector(
        input_dim=action_dim,
        llm_dim=llm_dim,
        rngs=rng,
    )

    return diffusion_action_head, noisy_action_projector


def lora_mask(tree):
    """Boolean pytree marking LoRA parameters (True) vs. everything else (False)."""

    def is_lora(path, v):
        return any("lora" in str(p) for p in path)

    return jax.tree_util.tree_map_with_path(is_lora, tree)
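
# Sketch of how this mask drives the optimizer in main(): optax.masked() only
# applies updates where the mask is True, so non-LoRA parameters stay frozen.
#
#   mask = lora_mask(nnx.state(model))
#   tx = optax.chain(optax.clip_by_global_norm(1.0), optax.masked(inner_tx, mask))
#
# Note: if the model contains no LoRA parameters, the mask is all-False and no
# parameters are updated at all.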


@at.typecheck
def init_train_state(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    init_rng: at.KeyArrayLike,
    mesh: jax.sharding.Mesh,
    tx: optax.GradientTransformation,
    *,
    resume: bool,
) -> tuple[training_utils.TrainState, Any]:
    def init(rng: at.KeyArrayLike, partial_params: at.Params | None = None) -> training_utils.TrainState:
        rng, model_rng = jax.random.split(rng)
        model = config.model.create(model_rng)
        model = apply_lora_to_model(model, config)
        if partial_params is not None:
            # Merge any loaded weights into the freshly initialized model.
            graphdef, state = nnx.split(model)
            state.replace_by_pure_dict(partial_params)
            model = nnx.merge(graphdef, state)
        params = nnx.state(model)
        # Cast frozen params to bfloat16 to save memory.
        params = nnx_utils.state_map(params, config.freeze_filter, lambda p: p.replace(p.value.astype(jnp.bfloat16)))

        return training_utils.TrainState(
            step=0,
            params=params,
            model_def=nnx.graphdef(model),
            tx=tx,
            opt_state=tx.init(params),
            ema_decay=config.ema_decay,
            ema_params=None if config.ema_decay is None else params,
        )

    train_state_shape = jax.eval_shape(init, init_rng)
    state_sharding = sharding.fsdp_sharding(train_state_shape, mesh, log=True)
    if resume:
        return train_state_shape, state_sharding

    partial_params = _load_weights_and_validate(config.weight_loader, train_state_shape.params.to_pure_dict())
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
    train_state = jax.jit(
        init,
        donate_argnums=(1,),
        in_shardings=replicated_sharding,
        out_shardings=state_sharding,
    )(init_rng, partial_params)
    return train_state, state_sharding


def compute_l1_loss(predicted_actions: jnp.ndarray, ground_truth_actions: jnp.ndarray) -> jnp.ndarray:
    return jnp.mean(jnp.abs(predicted_actions - ground_truth_actions))


def compute_diffusion_loss(predicted_noise: jnp.ndarray, target_noise: jnp.ndarray) -> jnp.ndarray:
    return jnp.mean((predicted_noise - target_noise) ** 2)
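
# The diffusion branch of compute_loss_with_oft_modes() uses the standard
# epsilon-prediction objective, E[ || eps - eps_hat(x_t, t) ||^2 ], i.e.
# compute_diffusion_loss(predicted_noise, injected_noise).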


def run_diffusion_sampling(
    model: _model.BaseModel,
    diffusion_action_head: DiffusionActionHead,
    noisy_action_projector: NoisyActionProjector,
    observation: _model.Observation,
    actions: _model.Actions,
    rng: at.KeyArrayLike,
    oft_config: OftTrainingConfig,
) -> jnp.ndarray:
    """Runs reverse-diffusion sampling through the Pi0FAST backbone and the
    NoisyActionProjector. `actions` is only used to determine output shapes."""
    if not isinstance(model, Pi0FAST):
        raise ValueError("run_diffusion_sampling only supports a Pi0FAST main model!")

    batch_size = actions.shape[0]
    action_dim = actions.shape[-1]
    action_horizon = actions.shape[1]

    # Start from pure Gaussian noise.
    noise = jax.random.normal(rng, (batch_size, action_horizon, action_dim))

    diffusion_action_head.noise_scheduler.set_timesteps(oft_config.num_diffusion_steps_train)

    # The observation embedding and attention masks do not depend on the
    # diffusion step, so compute them once outside the scan.
    obs_token_emb, input_mask, ar_mask = model.embed_inputs(observation)
    full_input_mask = jnp.concatenate(
        [input_mask, jnp.ones((batch_size, 2 * action_horizon), dtype=input_mask.dtype)], axis=1
    )
    full_ar_mask = jnp.concatenate(
        [ar_mask, jnp.zeros((batch_size, 2 * action_horizon), dtype=ar_mask.dtype)], axis=1
    )
    attn_mask = make_attn_mask(full_input_mask, full_ar_mask)
    attn_mask = attn_mask[:, None, :, :]
    obs_seq_len = obs_token_emb.shape[1]

    def diffusion_step(carry, timestep):
        curr_noisy_actions = carry
        timesteps = jnp.full((batch_size,), timestep)

        # Broadcast the per-example timestep embedding across the action horizon.
        diffusion_timestep_embeddings = diffusion_action_head.time_encoder(timesteps)
        diffusion_timestep_embeddings = jnp.expand_dims(diffusion_timestep_embeddings, 1)
        diffusion_timestep_embeddings = jnp.tile(diffusion_timestep_embeddings, (1, action_horizon, 1))

        noisy_action_emb = noisy_action_projector(curr_noisy_actions)

        # Sequence layout: [obs tokens | noisy action embeddings | timestep embeddings].
        full_emb = jnp.concatenate([obs_token_emb, noisy_action_emb, diffusion_timestep_embeddings], axis=1)

        hidden_states, _, _ = model.PaliGemma.llm(
            embedded_prefix=full_emb,
            mask=attn_mask,
            return_prelogits=True,
        )

        # Predict noise from the hidden states at the noisy-action positions.
        actions_hidden_states = hidden_states[:, obs_seq_len : obs_seq_len + action_horizon, :]
        noise_pred = diffusion_action_head.predict_noise(actions_hidden_states)

        prev_sample = diffusion_action_head.noise_scheduler.step(noise_pred, timestep, curr_noisy_actions)["prev_sample"]
        return prev_sample, None

    # scheduler.timesteps is in descending order (high noise -> low noise).
    final_sample, _ = jax.lax.scan(diffusion_step, noise, diffusion_action_head.noise_scheduler.timesteps)

    return final_sample
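
# Token layout shared by run_diffusion_sampling() and the diffusion branch of
# compute_loss_with_oft_modes():
#
#   position:  [0 .. L-1]  [L .. L+H-1]         [L+H .. L+2H-1]
#   content:   obs tokens  noisy action embeds  timestep embeds
#
# where L = obs_seq_len and H = action_horizon. full_ar_mask is zero over the
# 2H appended positions, so the action/time tokens attend bidirectionally
# among themselves and to the observation prefix.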


def compute_loss_with_oft_modes(
    model: _model.BaseModel,
    rng: at.KeyArrayLike,
    observation: _model.Observation,
    actions: _model.Actions,
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    diffusion_action_head: Optional[DiffusionActionHead] = None,
    noisy_action_projector: Optional[NoisyActionProjector] = None,
    train: bool = True,
) -> Tuple[jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Computes the loss for the selected openvla-oft action-head mode."""
    # Base Pi0FAST (autoregressive token) loss; it also serves as a stand-in
    # for the discrete and L1 modes below, which do not add a separate head here.
    chunked_loss = model.compute_loss(rng, observation, actions, train=train)
    base_loss = jnp.mean(chunked_loss)

    metrics = {"loss": base_loss}

    if oft_config.use_discrete_tokens:
        metrics["discrete_loss"] = base_loss

    elif oft_config.use_l1_regression:
        l1_loss = base_loss
        metrics["l1_loss"] = l1_loss
        metrics["regression_loss"] = l1_loss

    elif oft_config.use_diffusion and diffusion_action_head is not None:
        if not isinstance(model, Pi0FAST):
            raise ValueError("diffusion loss only supports a Pi0FAST main model!")
        if noisy_action_projector is None:
            raise ValueError("diffusion loss needs noisy_action_projector; it should not be None")

        batch_size = actions.shape[0]
        action_horizon = actions.shape[1]

        # Noise the ground-truth actions at random timesteps.
        noisy_dict = diffusion_action_head.sample_noisy_actions(actions, rng)
        noise = noisy_dict["noise"]
        noisy_actions = noisy_dict["noisy_actions"]
        diffusion_timestep_embeddings = noisy_dict["diffusion_timestep_embeddings"]

        noisy_action_emb = noisy_action_projector(noisy_actions)

        # Broadcast the per-example timestep embedding across the action horizon.
        diffusion_timestep_embeddings = jnp.expand_dims(diffusion_timestep_embeddings, 1)
        diffusion_timestep_embeddings = jnp.tile(diffusion_timestep_embeddings, (1, action_horizon, 1))
        obs_token_emb, input_mask, ar_mask = model.embed_inputs(observation)

        # Sequence layout: [obs tokens | noisy action embeddings | timestep embeddings].
        full_emb = jnp.concatenate([obs_token_emb, noisy_action_emb, diffusion_timestep_embeddings], axis=1)
        full_input_mask = jnp.concatenate(
            [input_mask, jnp.ones((batch_size, 2 * action_horizon), dtype=input_mask.dtype)], axis=1
        )
        full_ar_mask = jnp.concatenate(
            [ar_mask, jnp.zeros((batch_size, 2 * action_horizon), dtype=ar_mask.dtype)], axis=1
        )
        attn_mask = make_attn_mask(full_input_mask, full_ar_mask)
        attn_mask = attn_mask[:, None, :, :]
        hidden_states, _, _ = model.PaliGemma.llm(
            embedded_prefix=full_emb,
            mask=attn_mask,
            return_prelogits=True,
        )
        obs_seq_len = obs_token_emb.shape[1]

        actions_hidden_states = hidden_states[:, obs_seq_len : obs_seq_len + action_horizon, :]
        predicted_noise = diffusion_action_head.predict_noise(actions_hidden_states)

        diffusion_loss = compute_diffusion_loss(predicted_noise, noise)
        metrics["diffusion_loss"] = diffusion_loss
        metrics["noise_prediction_loss"] = diffusion_loss
        base_loss = diffusion_loss

    if "lora" in str(getattr(config.model, "paligemma_variant", "")):
        metrics["lora_loss"] = base_loss
        metrics["finetune_mode"] = jnp.array(1.0)

    return base_loss, metrics
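
# Mode dispatch sketch: with the defaults above (use_diffusion=True), the
# returned loss is the epsilon-prediction MSE and `metrics` contains
# {"loss", "diffusion_loss", "noise_prediction_loss"}; with
# use_discrete_tokens=True it is the Pi0FAST token loss and `metrics`
# contains {"loss", "discrete_loss"}.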


@at.typecheck
def train_step(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    rng: at.KeyArrayLike,
    state: training_utils.TrainState,
    batch: tuple[_model.Observation, _model.Actions],
) -> tuple[training_utils.TrainState, dict[str, at.Array]]:
    model = nnx.merge(state.model_def, state.params)
    model.train()

    train_rng = jax.random.fold_in(rng, state.step)
    observation, actions = batch

    # Note: the diffusion components are re-created (re-initialized) on every
    # step and are not part of the optimized TrainState, so their parameters
    # are not trained here.
    diffusion_action_head, noisy_action_projector = create_diffusion_components(config, oft_config, train_rng)

    def loss_fn(m, r, obs, acts):
        return compute_loss_with_oft_modes(
            m, r, obs, acts, config, oft_config, diffusion_action_head, noisy_action_projector, train=True
        )

    # Single forward/backward pass; metrics come back as the auxiliary output.
    diff_state = nnx.DiffState(0, config.trainable_filter)
    (_, metrics), grads = nnx.value_and_grad(loss_fn, argnums=diff_state, has_aux=True)(
        model, train_rng, observation, actions
    )

    params = state.params

    updates, new_opt_state = state.tx.update(grads, state.opt_state, params)
    new_params = optax.apply_updates(params, updates)

    new_state = dataclasses.replace(state, step=state.step + 1, params=new_params, opt_state=new_opt_state)
    if state.ema_decay is not None and state.ema_params is not None:
        # Exponential moving average: ema <- d * ema + (1 - d) * params.
        ema_decay = state.ema_decay
        new_state = dataclasses.replace(
            new_state,
            ema_params=jax.tree.map(
                lambda old, new: ema_decay * old + (1 - ema_decay) * new, state.ema_params, new_params
            ),
        )

    # Norm of the weight matrices only (biases, scales, and embeddings excluded).
    kernel_params = nnx.state(
        model,
        nnx.All(
            nnx.Param,
            nnx.Not(nnx_utils.PathRegex(".*/(bias|scale|pos_embedding|input_embedding)")),
            lambda _, x: x.value.ndim > 1,
        ),
    )

    info = {
        **metrics,
        "grad_norm": optax.global_norm(grads),
        "param_norm": optax.global_norm(kernel_params),
    }

    if diffusion_action_head is not None and noisy_action_projector is not None:
        # Expensive: runs the full reverse-diffusion loop on every training
        # step, purely so a sample can be logged.
        sampled_actions = run_diffusion_sampling(
            model, diffusion_action_head, noisy_action_projector, observation, actions, rng, oft_config
        )
        info["sampled_actions"] = sampled_actions[:1]

    return new_state, info
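
# Note: oft_config.grad_accumulation_steps only changes the logging cadence in
# main(); updates are applied on every step. True accumulation could be had by
# wrapping the optimizer (sketch, assuming the tx built in main()):
#
#   tx = optax.MultiSteps(tx, every_k_schedule=oft_config.grad_accumulation_steps)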


def run_validation(
    config: _config.TrainConfig,
    oft_config: OftTrainingConfig,
    state: training_utils.TrainState,
    val_data_loader,
    mesh: jax.sharding.Mesh,
    step: int,
) -> Dict[str, float]:
    """Runs a few validation batches and returns averaged metrics."""
    if not oft_config.use_val_set:
        return {}

    model = nnx.merge(state.model_def, state.params)
    model.eval()

    # A fixed key makes validation deterministic across calls.
    val_rng = jax.random.key(0)
    diffusion_action_head, noisy_action_projector = create_diffusion_components(config, oft_config, val_rng)

    val_metrics = []
    val_batches = 0

    for batch in val_data_loader:
        if val_batches >= 10:
            break

        observation, actions = batch

        _, metrics = compute_loss_with_oft_modes(
            model, val_rng, observation, actions, config, oft_config,
            diffusion_action_head, noisy_action_projector, train=False,
        )

        val_metrics.append(metrics)
        val_batches += 1

    avg_metrics = {}
    if val_metrics:
        for key in val_metrics[0].keys():
            avg_metrics[f"val_{key}"] = jnp.mean(jnp.array([m[key] for m in val_metrics]))

    return avg_metrics


def main(config: _config.TrainConfig):
    init_logging()
    logging.info(f"Running on: {platform.node()}")
    logging.info("Using openvla-oft enhanced training script")
    logging.info(f"Config: {config.name}")

    oft_config = OftTrainingConfig()

    if config.batch_size % jax.device_count() != 0:
        raise ValueError(
            f"Batch size {config.batch_size} must be divisible by the number of devices {jax.device_count()}."
        )

    jax.config.update("jax_compilation_cache_dir", str(epath.Path("~/.cache/jax").expanduser()))

    rng = jax.random.key(config.seed)
    train_rng, init_rng = jax.random.split(rng)

    mesh = sharding.make_mesh(config.fsdp_devices)
    data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(sharding.DATA_AXIS))
    replicated_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

    checkpoint_manager, resuming = _checkpoints.initialize_checkpoint_dir(
        str(config.checkpoint_dir),
        keep_period=config.keep_period,
        overwrite=config.overwrite,
        resume=config.resume,
    )
    init_wandb(config, oft_config, resuming=resuming, enabled=config.wandb_enabled)

    data_loader = _data_loader.create_data_loader(
        config,
        sharding=data_sharding,
        shuffle=True,
    )
    data_iter = iter(data_loader)
    batch = next(data_iter)
    logging.info(f"Initialized data loader:\n{training_utils.array_tree_to_info(batch)}")

    # Log a few example camera views from the first batch.
    images_to_log = [
        wandb.Image(np.concatenate([np.array(img[i]) for img in batch[0].images.values()], axis=1))
        for i in range(min(5, len(next(iter(batch[0].images.values())))))
    ]
    wandb.log({"camera_views": images_to_log}, step=0)

    # A throwaway model instance is created here only to derive the LoRA mask
    # with the same pytree structure as the train-state params.
    model = config.model.create(init_rng)
    model = apply_lora_to_model(model, config)
    params = nnx.state(model)
    mask = lora_mask(params)

    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.masked(
            _optimizer.create_optimizer(config.optimizer, config.lr_schedule, weight_decay_mask=None),
            mask,
        ),
    )

    train_state, train_state_sharding = init_train_state(
        config, oft_config, init_rng, mesh, tx=tx, resume=resuming
    )
    jax.block_until_ready(train_state)
    logging.info(f"Initialized train state:\n{training_utils.array_tree_to_info(train_state.params)}")

    if resuming:
        train_state = _checkpoints.restore_state(checkpoint_manager, train_state, data_loader)

    ptrain_step = jax.jit(
        functools.partial(train_step, config, oft_config),
        in_shardings=(replicated_sharding, train_state_sharding, data_sharding),
        out_shardings=(train_state_sharding, replicated_sharding),
        donate_argnums=(1,),
    )

    start_step = int(jax.device_get(train_state.step))
    pbar = tqdm.tqdm(
        range(start_step, config.num_train_steps),
        initial=start_step,
        total=config.num_train_steps,
        dynamic_ncols=True,
    )

    infos = []
    gradient_step = 0

    for step in pbar:
        with sharding.set_mesh(mesh):
            train_state, info = ptrain_step(train_rng, train_state, batch)
        infos.append(info)

        # grad_accumulation_steps only gates logging/validation cadence here;
        # parameter updates happen inside every ptrain_step call.
        if (step + 1) % oft_config.grad_accumulation_steps == 0:
            gradient_step += 1

            if gradient_step % config.log_interval == 0:
                stacked_infos = common_utils.stack_forest(infos)
                reduced_info = jax.device_get(jax.tree.map(jnp.mean, stacked_infos))
                info_str = ", ".join(f"{k}={v:.4f}" for k, v in reduced_info.items())
                pbar.write(f"Step {step}: {info_str}")
                wandb.log(reduced_info, step=step)
                infos = []

            # Note: no separate validation loader is created; this re-uses the
            # training loader.
            if oft_config.use_val_set and gradient_step % oft_config.val_freq == 0:
                val_metrics = run_validation(config, oft_config, train_state, data_loader, mesh, step)
                if val_metrics:
                    wandb.log(val_metrics, step=step)
                    pbar.write(f"Validation at step {step}: {val_metrics}")

        batch = next(data_iter)

        if (step % config.save_interval == 0 and step > start_step) or step == config.num_train_steps - 1:
            _checkpoints.save_state(checkpoint_manager, train_state, data_loader, step)

    logging.info("Waiting for checkpoint manager to finish")
    checkpoint_manager.wait_until_finished()


if __name__ == "__main__":
    main(_config.cli())