# cross13tasks / code / training / train_qwenlatent.py
# Uploaded by: Timsty
# Upload folder using huggingface_hub
# e94400c verified
# Copyright 2025 starVLA community. All rights reserved.
# Licensed under the MIT License, Version 1.0 (the "License");
# Implemented by [Jinhui YE / HKUST University] in [2025].
import sys
sys.path.append("/mnt/data/fangyu/code/reward_new")
"""
StarVLA’s trainer is built directly on native PyTorch + Accelerate + DeepSpeed, keeping the loop explicit and easy to hack.
Conventions:
1. Store runtime state in dicts where possible (simplifies data info, processing info, config, etc).
2. Use multiple dataloaders to adapt heterogeneous data types / task mixtures.
3. Put each training strategy in its own `trainer_*.py` file (avoid large if‑else chains).
"""
import warnings
warnings.filterwarnings("ignore")
# Standard Library
import argparse
import json
import os
# SECURITY: a Weights & Biases API key must never be hard-coded in source control. The
# literal that used to be assigned here should be treated as leaked and rotated.
# wandb reads WANDB_API_KEY from the environment (or from `wandb login`), so the launch
# script is responsible for exporting it; here we only warn when it is missing.
if not os.environ.get("WANDB_API_KEY"):
    print(
        "WANDB_API_KEY is not set; export it (or run `wandb login`) before launching training.",
        file=sys.stderr,
    )
from pathlib import Path
from typing import Tuple
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time
import glob
import re
# Third-Party Libraries
import torch
import torch.distributed as dist
import wandb
import yaml
from accelerate import Accelerator, DeepSpeedPlugin
from accelerate.logging import get_logger
from accelerate.utils import set_seed, DistributedType
from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import AutoProcessor, get_scheduler
# Local Modules
from starVLA.training.trainer_utils.trainer_tools import normalize_dotlist_args
from starVLA.model.framework import build_framework
from starVLA.training.trainer_utils.trainer_tools import TrainerUtils
from starVLA.training.trainer_utils.trainer_tools import build_param_lr_groups
# Module-level (import-time) setup: the Accelerator is created once here and shared by
# `main()` and `VLATrainer` through the global name `accelerator`.
# `DeepSpeedPlugin()` is built with no explicit arguments.
deepspeed_plugin = DeepSpeedPlugin()
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
# Print the resolved Accelerate state via `accelerator.print`.
accelerator.print(accelerator.state)
# Sane Defaults
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Initialize Overwatch =>> Wraps `logging.Logger`
# NOTE(review): `get_logger` is already imported near the top of the file; this second
# import is redundant but harmless.
from accelerate.logging import get_logger
logger = get_logger(__name__)
def load_fast_tokenizer():
    """Load and return the `physical-intelligence/fast` processor through `AutoProcessor`.

    `trust_remote_code=True` is passed, so code shipped with that repository is executed.
    """
    return AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)
def setup_directories(cfg) -> Path:
    """Resolve the run directory, and on rank 0 create it and dump the config.

    Side effects:
        * `cfg.output_dir` is set to `<cfg.run_root_dir>/<cfg.run_id>` on every rank.
        * On rank 0 (or when no process group is initialised) the run directory and its
          `checkpoints/` sub-directory are created, and the config is written as
          `config.yaml` and re-exported as `config.json`.

    Returns:
        The run directory as a `Path`.
    """
    cfg.output_dir = os.path.join(cfg.run_root_dir, cfg.run_id)
    run_dir = Path(cfg.output_dir)

    # Only one process writes to disk; every other rank just gets the path back.
    is_writer = (not dist.is_initialized()) or dist.get_rank() == 0
    if not is_writer:
        return run_dir

    for directory in (run_dir, run_dir / "checkpoints"):
        os.makedirs(directory, exist_ok=True)

    yaml_path = run_dir / "config.yaml"
    json_path = run_dir / "config.json"
    OmegaConf.save(cfg, yaml_path)
    # Round-trip through the YAML file so the JSON copy holds plain Python containers.
    with open(yaml_path, "r") as yaml_file, open(json_path, "w") as json_file:
        json.dump(yaml.safe_load(yaml_file), json_file, indent=2)

    return run_dir
def build_model(cfg) -> torch.nn.Module:
    """Log the configured base VLM and return the framework built by `build_framework(cfg)`."""
    base_vlm = cfg.framework.qwenvl.base_vlm
    logger.info(f"Loading Base VLM `{base_vlm}` from ID/Path")
    return build_framework(cfg)
# here changes need to 📦 encapsulate Dataloader
from starVLA.dataloader import build_dataloader
def prepare_data(cfg, accelerator, output_dir) -> DataLoader:
    """Build the VLA training dataloader.

    Args:
        cfg: Config object; `cfg.datasets.vla_data.data_mix` is logged and
            `cfg.datasets.vla_data.dataset_py` is forwarded to `build_dataloader`.
        accelerator: Accelerate `Accelerator`; its `dataloader_config.dispatch_batches`
            is set to False as a side effect.
        output_dir: Unused here; kept so existing callers keep working.

    Returns:
        The single dataloader returned by `build_dataloader`. (The previous annotation
        advertised a 2-tuple, which did not match the returned value.)
    """
    # VLA data loader
    logger.info(f"Creating VLA Dataset with Mixture `{cfg.datasets.vla_data.data_mix}`")
    vla_train_dataloader = build_dataloader(cfg=cfg, dataset_py=cfg.datasets.vla_data.dataset_py)

    accelerator.dataloader_config.dispatch_batches = False

    # Synchronise ranks once the dataset is built. Guarded so the function does not raise
    # when no process group exists (mirrors the check in `setup_directories`).
    if dist.is_initialized():
        dist.barrier()

    return vla_train_dataloader
def get_warmup_stable_cosine_scheduler(optimizer, num_warmup_steps, num_stable_steps, num_training_steps, min_lr_ratio=0.01):
    """Build a Warmup -> Stable -> Cosine-decay learning-rate scheduler.

    The multiplier applied to each param group's base LR is:
      * linear from 0 to 1 over the first `num_warmup_steps` steps,
      * 1.0 for the next `num_stable_steps` steps,
      * a half-cosine from 1.0 down to `min_lr_ratio`, reached at `num_training_steps`,
      * `min_lr_ratio` for any step at or beyond `num_training_steps`.

    Args:
        optimizer: PyTorch optimizer (may hold several param groups).
        num_warmup_steps: Number of warmup steps.
        num_stable_steps: Number of steps held at the peak LR after warmup.
        num_training_steps: Total number of training steps.
        min_lr_ratio: Final LR divided by peak LR.

    Returns:
        A `torch.optim.lr_scheduler.LambdaLR` using the same schedule for every group.
    """
    import math

    stable_end = num_warmup_steps + num_stable_steps
    decay_steps = num_training_steps - stable_end

    def lr_lambda(current_step):
        # Warmup phase: linear ramp.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Stable phase: hold the peak LR.
        if current_step < stable_end:
            return 1.0
        # No room left for a decay phase: sit at the floor.
        if decay_steps <= 0:
            return min_lr_ratio
        # Cosine decay phase. Progress is clamped to 1.0 so that stepping past
        # `num_training_steps` stays at the floor instead of riding the cosine back up.
        progress = min(1.0, float(current_step - stable_end) / float(decay_steps))
        return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

    # One lambda per param group so multi-group optimizers are supported.
    num_param_groups = len(optimizer.param_groups)
    return torch.optim.lr_scheduler.LambdaLR(optimizer, [lr_lambda] * num_param_groups)
def setup_optimizer_and_scheduler(model, cfg) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]:
    """Create the AdamW optimizer and the LR scheduler selected by the config.

    Args:
        model: Model whose parameters are grouped by `build_param_lr_groups`.
        cfg: Config object. Reads `cfg.trainer.learning_rate.base`,
            `cfg.trainer.optimizer.{betas,weight_decay,eps}`, `cfg.trainer.lr_scheduler_type`,
            `cfg.trainer.num_warmup_steps`, `cfg.trainer.max_train_steps`,
            `cfg.trainer.scheduler_specific_kwargs` and, optionally,
            `cfg.trainer.num_stable_steps`.

    Returns:
        `(optimizer, lr_scheduler)`. The scheduler is one of:
          * ``"warmup_stable_cosine"`` -> `get_warmup_stable_cosine_scheduler`,
          * ``"onecycle"`` -> `torch.optim.lr_scheduler.OneCycleLR`,
          * anything else -> `transformers.get_scheduler`.
    """
    # initialize optimizer
    param_groups = build_param_lr_groups(model=model, cfg=cfg)
    optimizer = torch.optim.AdamW(
        param_groups,
        lr=cfg.trainer.learning_rate.base,
        betas=tuple(cfg.trainer.optimizer.betas),
        weight_decay=cfg.trainer.optimizer.weight_decay,
        eps=cfg.trainer.optimizer.eps,
    )

    is_rank_zero = dist.is_initialized() and dist.get_rank() == 0

    # print optimizer group info
    if is_rank_zero:
        for index, group in enumerate(optimizer.param_groups):
            # Fall back to the group index when a group carries no "name" entry.
            group_name = group.get("name", index)
            logger.info(f"LR Group {group_name}: lr={group['lr']}, num_params={len(group['params'])}")

    scheduler_type = cfg.trainer.lr_scheduler_type
    num_warmup_steps = cfg.trainer.num_warmup_steps
    max_train_steps = cfg.trainer.max_train_steps

    if scheduler_type == "warmup_stable_cosine":
        # Custom scheduler: Warmup -> Stable -> Cosine Decay.
        # `scheduler_specific_kwargs` may be None/empty; treat that as "no overrides",
        # the same way the onecycle branch does.
        scheduler_kwargs = cfg.trainer.scheduler_specific_kwargs or {}
        min_lr_ratio = scheduler_kwargs.get("min_lr_ratio", 0.01)
        num_stable_steps = cfg.trainer.get("num_stable_steps", 0)
        lr_scheduler = get_warmup_stable_cosine_scheduler(
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_stable_steps=num_stable_steps,
            num_training_steps=max_train_steps,
            min_lr_ratio=min_lr_ratio,
        )
        if is_rank_zero:
            logger.info(f"Using warmup_stable_cosine scheduler: warmup={cfg.trainer.num_warmup_steps}, "
                        f"stable={num_stable_steps}, total={cfg.trainer.max_train_steps}, min_lr_ratio={min_lr_ratio}")
    elif scheduler_type == "onecycle":
        # OneCycleLR: supports multiple param groups with different peak lrs.
        scheduler_kwargs = cfg.trainer.scheduler_specific_kwargs or {}

        # Default the warmup fraction to num_warmup_steps / max_train_steps, clamped to [0, 1].
        pct_start = scheduler_kwargs.get("pct_start", None)
        if pct_start is None:
            pct_start = float(num_warmup_steps) / float(max(1, max_train_steps))
        pct_start = max(0.0, min(1.0, float(pct_start)))

        # Resolve each option once so the scheduler and the log line cannot disagree.
        anneal_strategy = scheduler_kwargs.get("anneal_strategy", "cos")
        cycle_momentum = scheduler_kwargs.get("cycle_momentum", False)
        div_factor = scheduler_kwargs.get("div_factor", 25.0)
        final_div_factor = scheduler_kwargs.get("final_div_factor", 10000.0)
        three_phase = scheduler_kwargs.get("three_phase", False)

        max_lrs = [group["lr"] for group in optimizer.param_groups]
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer=optimizer,
            max_lr=max_lrs,
            total_steps=max_train_steps,
            pct_start=pct_start,
            anneal_strategy=anneal_strategy,
            cycle_momentum=cycle_momentum,
            div_factor=div_factor,
            final_div_factor=final_div_factor,
            three_phase=three_phase,
        )
        if is_rank_zero:
            logger.info(
                "Using onecycle scheduler: total=%s, pct_start=%.6f, max_lrs=%s, anneal=%s, "
                "div_factor=%s, final_div_factor=%s, cycle_momentum=%s, three_phase=%s",
                max_train_steps,
                pct_start,
                max_lrs,
                anneal_strategy,
                div_factor,
                final_div_factor,
                cycle_momentum,
                three_phase,
            )
    else:
        # Use the scheduler factory built into transformers.
        lr_scheduler = get_scheduler(
            name=scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=max_train_steps,
            scheduler_specific_kwargs=cfg.trainer.scheduler_specific_kwargs,
        )

    return optimizer, lr_scheduler
class VLATrainer(TrainerUtils):
    """Single-dataloader VLA trainer with an explicit Accelerate (optionally DeepSpeed) loop.

    Lifecycle: construct -> `prepare_training()` -> `train()`. The `optimizer` and
    `lr_scheduler` handed to the constructor may be ``None``; they are (re)built inside
    `prepare_training()` once checkpoints have been loaded.
    """

    def __init__(self, cfg, model, vla_train_dataloader, optimizer, lr_scheduler, accelerator):
        """Store the training components and initialise step/metric bookkeeping."""
        self.config = cfg
        self.model = model
        self.vla_train_dataloader = vla_train_dataloader
        # Note: optimizer/lr_scheduler are intentionally created in `prepare_training()`
        # after we load checkpoints and freeze modules, to avoid empty param-groups in
        # DeepSpeed ZeRO initialization.
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.accelerator = accelerator
        self._printed_first_batch = False

        # training status tracking
        self.completed_steps = 0
        self.total_batch_size = self._calculate_total_batch_size()
        # Gradient norms collected since the last metrics emit (see `_log_metrics`).
        self._grad_norm_buffer: list[float] = []
        self.training_mode = getattr(self.config.trainer, "mode", "default")
        self.loss_weights_decay_steps = int(getattr(self.config.trainer, "loss_weights_decay_steps", 5000))
        if self.loss_weights_decay_steps <= 0:
            # A non-positive horizon would divide by zero in `_get_aux_loss_decay_weight`.
            logger.warning(
                f"Invalid loss_weights_decay_steps={self.loss_weights_decay_steps}, fallback to 1."
            )
            self.loss_weights_decay_steps = 1

    def _debug_print_first_batch(self, batch) -> None:
        """Print the structure and contents of the first batch, once, on the local main process."""
        if self._printed_first_batch or not self.accelerator.is_local_main_process:
            return
        self._printed_first_batch = True

        sample = None
        if isinstance(batch, list):
            sample = batch[0] if len(batch) > 0 else None
        elif isinstance(batch, dict):
            sample = batch
        if sample is None:
            self.accelerator.print("First batch is empty.")
            return

        def _describe_value(value):
            """Return a short type/shape description of `value`."""
            if hasattr(value, "shape"):
                try:
                    return f"{type(value).__name__}(shape={tuple(value.shape)})"
                except Exception:
                    return type(value).__name__
            if isinstance(value, list):
                inner = type(value[0]).__name__ if value else "empty"
                return f"list(len={len(value)}, inner={inner})"
            return type(value).__name__

        self.accelerator.print(f"First batch type: {type(batch).__name__}")
        if isinstance(batch, list):
            self.accelerator.print(f"First batch size: {len(batch)}")
        self.accelerator.print("First sample keys:")
        for key, value in sample.items():
            self.accelerator.print(f"  - {key}: {_describe_value(value)}")

        # Print full content for first 5 samples to inspect inputs.
        if isinstance(batch, list):
            max_samples = min(5, len(batch))
            for i in range(max_samples):
                self.accelerator.print(f"Sample[{i}] content:")
                for key, value in batch[i].items():
                    if hasattr(value, "shape"):
                        try:
                            value_str = np.array2string(
                                value, threshold=np.inf, max_line_width=200
                            )
                        except Exception:
                            # Not something numpy can format; fall back to repr.
                            value_str = repr(value)
                    else:
                        value_str = repr(value)
                    self.accelerator.print(f"  - {key}: {value_str}")

    def prepare_training(self):
        """Seed, load pretrained weights, build optimizer/scheduler and wrap for distributed training.

        Also initialises W&B and the checkpoint directory (and resumes when configured).
        """
        rank = dist.get_rank() if dist.is_initialized() else 0
        # Offset the seed by rank so every process gets a different seed.
        seed = self.config.seed + rank if hasattr(self.config, "seed") else rank + 3047
        set_seed(seed)

        # load pretrained weights
        # If action_model already loaded weights from ckpt_path during model construction,
        # they must not be overwritten by the pretrained checkpoint below.
        action_model_ckpt_path = getattr(self.config.framework.action_model, "ckpt_path", None)
        if action_model_ckpt_path:
            # Snapshot the action_model weights so they can be verified after the reload.
            action_model_state_before = {
                k: v.clone() for k, v in self.model.action_model.state_dict().items()
            }

        if hasattr(self.config.trainer, "pretrained_checkpoint") and self.config.trainer.pretrained_checkpoint:
            pretrained_checkpoint = self.config.trainer.pretrained_checkpoint
            reload_modules = (
                self.config.trainer.reload_modules if hasattr(self.config.trainer, "reload_modules") else None
            )

            # When action_model has preloaded weights and reload_modules is not given,
            # warn if the checkpoint also carries action_model weights.
            if action_model_ckpt_path and not reload_modules:
                try:
                    checkpoint = torch.load(pretrained_checkpoint, map_location="cpu")
                    has_action_model_keys = any(k.startswith("action_model.") for k in checkpoint.keys())
                    # Only the key names are needed; drop the tensors right away so they
                    # do not stay resident through the rest of the setup.
                    del checkpoint
                    if has_action_model_keys:
                        logger.warning(
                            f"⚠️ pretrained_checkpoint contains action_model weights, but action_model "
                            f"was already loaded from {action_model_ckpt_path}. "
                            f"Will reload action_model from {action_model_ckpt_path} after loading checkpoint."
                        )
                except Exception as exc:
                    # Best-effort probe: if the checkpoint cannot be read this way, carry on
                    # with the normal loading path, but leave a trace instead of hiding it.
                    logger.debug(f"Could not inspect pretrained_checkpoint {pretrained_checkpoint}: {exc}")

            self.model = self.load_pretrained_backbones(self.model, pretrained_checkpoint, reload_modules=reload_modules)

            # Re-apply the action_model checkpoint so the pretrained checkpoint cannot
            # have overwritten it.
            # NOTE(review): nesting under the pretrained-checkpoint branch is reconstructed
            # from the log messages; confirm against the original indentation.
            if action_model_ckpt_path:
                logger.info(f"🔄 Reloading action_model from {action_model_ckpt_path} to ensure correct weights")
                self.model.action_model.load_state_dict(
                    torch.load(action_model_ckpt_path, map_location="cpu"), strict=True
                )
                # Verify that the weights match the snapshot taken before loading.
                action_model_state_after = self.model.action_model.state_dict()
                mismatched = [
                    k
                    for k in action_model_state_before
                    if not torch.equal(action_model_state_before[k], action_model_state_after[k])
                ]
                if mismatched:
                    logger.error(f"❌ action_model weights mismatch after reload: {mismatched}")
                else:
                    logger.info("✅ action_model weights verified after checkpoint loading")

        # print model trainable parameters:
        self.print_trainable_parameters(self.model)

        # build optimizer and scheduler AFTER freezing (critical for DeepSpeed ZeRO)
        self.optimizer, self.lr_scheduler = setup_optimizer_and_scheduler(model=self.model, cfg=self.config)

        # initialize distributed training components
        # Note: lr_scheduler is deliberately NOT passed in, so it is not wrapped by
        # AcceleratedScheduler (which would call step num_processes times).
        self.model, self.optimizer, self.vla_train_dataloader = self.setup_distributed_training(
            self.accelerator,  # must be the first param
            self.model,
            self.optimizer,
            self.vla_train_dataloader,
        )

        self._init_wandb()
        self._init_checkpointing()

    def _calculate_total_batch_size(self):
        """calculate global batch size"""
        return (
            self.config.datasets.vla_data.per_device_batch_size
            * self.accelerator.num_processes
            * self.accelerator.gradient_accumulation_steps
        )

    def _init_wandb(self):
        """initialize Weights & Biases (main process only)"""
        if self.accelerator.is_main_process:
            wandb.init(
                name=self.config.run_id,
                dir=os.path.join(self.config.output_dir, "wandb"),
                project=self.config.wandb_project,
                entity=self.config.wandb_entity,
                group="vla-train",
                settings=wandb.Settings(
                    _disable_stats=False,  # keep system monitoring enabled
                    x_stats_sampling_interval=10.0,  # sample system metrics every 10 seconds
                ),
            )

    def _init_checkpointing(self):
        """initialize checkpoint directory and optionally resume training state"""
        self.checkpoint_dir = os.path.join(self.config.output_dir, "checkpoints")
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        pretrained_checkpoint = getattr(self.config.trainer, "pretrained_checkpoint", None)
        is_resume = getattr(self.config.trainer, "is_resume", False)

        # resume train ckpt
        # NOTE(review): the guard looks at `trainer.pretrained_checkpoint` but the path that
        # is loaded is the top-level `config.resume_from_checkpoint` - confirm this is intended.
        if pretrained_checkpoint and is_resume:
            self._load_checkpoint(self.config.resume_from_checkpoint)

    def _load_checkpoint(self, checkpoint_path):
        """load checkpoint"""
        self.accelerator.load_state(checkpoint_path)
        self.accelerator.print(f"Resumed from checkpoint: {checkpoint_path}")

    def _state_dict_needs_all_ranks(self) -> bool:
        """Return True when `accelerator.get_state_dict` must be entered by every rank.

        With sharded parameters (FSDP, DeepSpeed ZeRO-3) gathering the full state dict is a
        collective operation; calling it on the main process alone can dead-lock the job.
        When the ZeRO stage cannot be determined the conservative answer (True) is returned.
        """
        distributed_type = self.accelerator.distributed_type
        if distributed_type == DistributedType.FSDP:
            return True
        if distributed_type != DistributedType.DEEPSPEED:
            return False
        plugin = getattr(self.accelerator.state, "deepspeed_plugin", None)
        ds_config = getattr(plugin, "deepspeed_config", None) or {}
        zero_section = ds_config.get("zero_optimization") or {}
        stage = zero_section.get("stage", getattr(plugin, "zero_stage", None))
        return stage not in (0, 1, 2)

    def _gather_state_dict(self):
        """Return the model state dict on the main process (and wherever a gather is required).

        Ranks that do not need to take part get ``None`` and avoid the extra copy.
        """
        if self._state_dict_needs_all_ranks() or self.accelerator.is_main_process:
            return self.accelerator.get_state_dict(self.model)
        return None

    def _save_checkpoint(self):
        """save current training state"""
        # Must run before the main-process guard: under sharded training every rank has
        # to take part in gathering the weights.
        state_dict = self._gather_state_dict()

        if self.accelerator.is_main_process:
            checkpoint_path = os.path.join(self.checkpoint_dir, f"steps_{self.completed_steps}")
            # save model state
            torch.save(state_dict, checkpoint_path + "_pytorch_model.pt")

            # save training metadata
            summary_data = {
                "steps": self.completed_steps,
            }
            with open(os.path.join(self.config.output_dir, "summary.jsonl"), "a") as f:
                f.write(json.dumps(summary_data) + "\n")
            self.accelerator.print(f"✅ Checkpoint saved at {checkpoint_path}")

            # Drop old checkpoints, keeping only the most recent N.
            max_checkpoints = getattr(self.config.trainer, "max_checkpoints_to_keep", None)
            if max_checkpoints is not None and max_checkpoints > 0:
                self._cleanup_old_checkpoints(max_checkpoints)

        del state_dict
        self.accelerator.wait_for_everyone()

    def _cleanup_old_checkpoints(self, max_checkpoints: int):
        """Delete old checkpoints, keeping only the `max_checkpoints` with the highest step."""
        # Main process only, to avoid races between processes.
        if not self.accelerator.is_main_process:
            return

        # Collect every checkpoint file.
        checkpoint_pattern = os.path.join(self.checkpoint_dir, "steps_*_pytorch_model.pt")
        checkpoint_files = glob.glob(checkpoint_pattern)
        if len(checkpoint_files) <= max_checkpoints:
            return

        # Sort by the step number embedded in the file name.
        def extract_steps(filepath):
            match = re.search(r'steps_(\d+)_pytorch_model\.pt', filepath)
            return int(match.group(1)) if match else 0

        checkpoint_files.sort(key=extract_steps)

        # Remove the oldest ones.
        files_to_delete = checkpoint_files[:-max_checkpoints]
        for filepath in files_to_delete:
            try:
                os.remove(filepath)
                self.accelerator.print(f"🗑️ Deleted old checkpoint: {os.path.basename(filepath)}")
            except Exception as e:
                self.accelerator.print(f"⚠️ Failed to delete checkpoint {filepath}: {e}")

    def _log_metrics(self, metrics):
        """record training metrics every `trainer.logging_frequency` completed steps"""
        if self.completed_steps % self.config.trainer.logging_frequency != 0:
            return

        # Average grad_norm over the logging window (cleared every emit).
        if self._grad_norm_buffer:
            metrics["grad_norm_pre_clip_avg"] = float(
                sum(self._grad_norm_buffer) / len(self._grad_norm_buffer)
            )
            self._grad_norm_buffer.clear()

        # `is_main_process` instead of `dist.get_rank()` so this also works when no
        # process group has been initialised.
        if not self.accelerator.is_main_process:
            return

        # add learning rate
        metrics["learning_rate"] = self.lr_scheduler.get_last_lr()[0]  # see lr group in yaml.trainer.learning_rate

        # add epoch info (skipped when the dataloader has no usable length)
        try:
            steps_per_epoch = len(self.vla_train_dataloader)
        except TypeError:
            steps_per_epoch = 0
        if steps_per_epoch > 0:
            metrics["epoch"] = round(self.completed_steps / steps_per_epoch, 2)

        # record to W&B
        wandb.log(metrics, step=self.completed_steps)

        # debug output
        gn_str = f"{metrics['grad_norm_pre_clip']:.4f}" if "grad_norm_pre_clip" in metrics else "N/A"
        gn_avg_str = f"{metrics['grad_norm_pre_clip_avg']:.4f}" if "grad_norm_pre_clip_avg" in metrics else "N/A"
        logger.info(
            f"\nStep {self.completed_steps} | "
            f"grad_norm_pre_clip={gn_str} | grad_norm_pre_clip_avg={gn_avg_str} | "
            f"Metrics: {metrics}"
        )

    def _create_data_iterators(self):
        """create data iterators"""
        self.vla_iter = iter(self.vla_train_dataloader)
        # self.vlm_iter = iter(self.vlm_train_dataloader)

    def _get_next_batch(self):
        """get next batch (automatically handle data loop)"""
        try:
            batch_vla = next(self.vla_iter)
        except StopIteration:
            # Dataloader exhausted: start a new epoch and pull from the fresh iterator.
            if not hasattr(self, "vla_epoch_count"):
                self.vla_epoch_count = 0
            self.vla_iter, self.vla_epoch_count = TrainerUtils._reset_dataloader(
                self.vla_train_dataloader, self.vla_epoch_count
            )
            batch_vla = next(self.vla_iter)
        return batch_vla

    def train(self):
        """execute training loop

        Runs until `trainer.max_train_steps` optimizer steps are completed. Evaluation,
        metric logging and checkpointing are evaluated once per optimizer step (i.e. only
        when gradients were synchronised), so gradient accumulation does not trigger them
        repeatedly for the same `completed_steps` value.
        """
        # print training config
        self._log_training_config()
        # prepare data iterators
        self._create_data_iterators()

        max_train_steps = self.config.trainer.max_train_steps
        # create progress bar
        progress_bar = tqdm(
            range(max_train_steps), disable=not self.accelerator.is_local_main_process
        )

        # main training loop
        while self.completed_steps < max_train_steps:
            # get data batch
            t_start_data = time.perf_counter()
            batch_vla = self._get_next_batch()
            self._debug_print_first_batch(batch_vla)
            t_end_data = time.perf_counter()

            # execute training step
            t_start_model = time.perf_counter()
            step_metrics = self._train_step(batch_vla)
            t_end_model = time.perf_counter()

            # update progress
            optimizer_stepped = self.accelerator.sync_gradients
            if optimizer_stepped:
                progress_bar.update(1)
                self.completed_steps += 1

            if self.accelerator.is_local_main_process:
                progress_bar.set_postfix(
                    {
                        "data_times": f"{t_end_data - t_start_data:.3f}",
                        "model_times": f"{t_end_model - t_start_model:.3f}",
                    }
                )

            # Pure accumulation micro-step: `completed_steps` did not move, so the
            # periodic work below would only repeat what was already done.
            if not optimizer_stepped:
                continue

            # evaluate model (reuse current training batch to avoid consuming extra samples)
            if self.completed_steps % self.config.trainer.eval_interval == 0:
                step_metrics = self.eval_action_model(step_metrics, examples=batch_vla)

            # record metrics
            step_metrics["data_time"] = t_end_data - t_start_data
            step_metrics["model_time"] = t_end_model - t_start_model
            self._log_metrics(step_metrics)

            # save checkpoint
            if self.completed_steps % self.config.trainer.save_interval == 0 and self.completed_steps > 0:
                self._save_checkpoint()

        progress_bar.close()

        # training end processing
        self._finalize_training()

    # execute evaluation step
    def eval_action_model(self, step_metrics: dict = None, examples=None) -> dict:
        """Predict actions for one batch and record the mean absolute error.

        :param step_metrics: Metrics dict to extend; a new dict is created when omitted.
        :param examples: Batch to evaluate on (an iterable of samples that each provide an
            "action" entry). When omitted, the next training batch is pulled.
        :return: `step_metrics`, with "mae_score" added on the main process.
        """
        if step_metrics is None:
            step_metrics = {}
        if examples is None:
            examples = self._get_next_batch()

        # When using history, actions contain both history and future.
        # We only evaluate on the future part (predicted actions).
        num_history = self.model.num_history_steps
        start = num_history if num_history > 0 else 0
        end = start + self.model.chunk_size
        actions = [example["action"][start:end] for example in examples]  # labels aligned with the predicted chunk

        # Predict actions using the model
        output_dict = self.model.predict_action(examples=examples)

        if self.accelerator.is_main_process:
            normalized_actions = output_dict["normalized_actions"]  # B, T, D
            actions = np.array(actions)  # convert actions to numpy.ndarray
            # B, Chunk, dim = actions.shape
            num_pots = max(int(np.prod(actions.shape)), 1)  # guard against an empty label array
            # Compute the metric score (L1 = MAE, easier to interpret)
            score = TrainerUtils.l1_distance(normalized_actions, actions)
            step_metrics["mae_score"] = score / num_pots

        if dist.is_initialized():
            dist.barrier()  # ensure all processes are synchronized
        return step_metrics

    def _log_training_config(self):
        """record training config"""
        if self.accelerator.is_main_process:
            logger.info("***** Training Configuration *****")
            logger.info(f"  Total optimization steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  Per device batch size = {self.config.datasets.vla_data.per_device_batch_size}")
            logger.info(f"  Gradient accumulation steps = {self.config.trainer.gradient_accumulation_steps}")
            logger.info(f"  Total batch size = {self.total_batch_size}")
            logger.info("***** LR Scheduler Debug Info *****")
            logger.info(f"  lr_scheduler type = {type(self.lr_scheduler)}")
            # Unwrap one level in case the scheduler is wrapped in an object exposing `.scheduler`.
            base_scheduler = getattr(self.lr_scheduler, 'scheduler', self.lr_scheduler)
            logger.info(f"  base_scheduler type = {type(base_scheduler)}")
            logger.info(f"  initial last_epoch = {getattr(base_scheduler, 'last_epoch', 'N/A')}")
            logger.info(f"  initial lr = {self.lr_scheduler.get_last_lr()}")
            logger.info(f"  num_warmup_steps = {self.config.trainer.num_warmup_steps}")
            logger.info(f"  num_stable_steps = {self.config.trainer.get('num_stable_steps', 0)}")
            logger.info(f"  max_train_steps = {self.config.trainer.max_train_steps}")
            logger.info(f"  accelerator.num_processes = {self.accelerator.num_processes}")
            logger.info(f"  accelerator.gradient_accumulation_steps = {self.accelerator.gradient_accumulation_steps}")
            logger.info(f"  trainer.mode = {self.training_mode}")
            logger.info(f"  loss_weights_decay_steps = {self.loss_weights_decay_steps}")

    def _get_aux_loss_decay_weight(self) -> float:
        """Return the multiplier for the auxiliary (align/recon) losses.

        1.0 unless `trainer.mode == "decay_aux_loss"`, in which case it decays linearly
        from 1.0 to 0.0 over `loss_weights_decay_steps` completed steps.
        """
        if self.training_mode != "decay_aux_loss":
            return 1.0
        progress = min(float(self.completed_steps) / float(self.loss_weights_decay_steps), 1.0)
        return 1.0 - progress

    @staticmethod
    def _total_grad_norm_l2_local(parameters) -> float:
        """L2 norm over all grads (same recipe as torch.nn.utils.clip_grad_norm_). DeepSpeed-safe fallback when clip_grad_norm_ returns None."""
        total_sq = 0.0
        for p in parameters:
            if p.grad is None:
                continue
            # float32 for stable norm under bf16/fp16 grads
            param_norm = p.grad.detach().float().norm(2)
            total_sq += float(param_norm) ** 2
        return total_sq ** 0.5

    @staticmethod
    def _grad_norm_scalar(value) -> float:
        """Convert a grad-norm value (None, tensor or number) to float; None becomes NaN."""
        if value is None:
            return float("nan")
        if isinstance(value, torch.Tensor):
            return float(value.detach().item())
        return float(value)

    def _train_step(self, batch_vla, batch_vlm=None):
        """execute single training step

        Runs forward/backward for one micro-batch inside `accelerator.accumulate` and
        returns a dict with the individual losses, the aux-loss weight and the pre-clip
        gradient norm (NaN when no norm is available for this micro-step).
        `batch_vlm` is accepted for interface compatibility and is not used.
        """
        is_deepspeed = self.accelerator.distributed_type == DistributedType.DEEPSPEED
        grad_norm_pre_clip = None

        with self.accelerator.accumulate(self.model):
            # VLA forward pass (training_step is passed so that the history sampling is
            # identical on every rank and ranks do not drift apart).
            with torch.autocast("cuda", dtype=torch.bfloat16):
                output_dict = self.model.forward(batch_vla, training_step=self.completed_steps)
                align_loss = output_dict["align_loss"]
                recon_loss = output_dict["recon_loss"]
                predict_loss = output_dict["predict_loss"]
                aux_loss_decay_weight = self._get_aux_loss_decay_weight()
                if align_loss is not None and recon_loss is not None:
                    total_loss = (
                        self.config.trainer.loss_scale.align_loss * aux_loss_decay_weight * align_loss
                        + self.config.trainer.loss_scale.recon_loss * aux_loss_decay_weight * recon_loss
                        + predict_loss
                    )
                else:
                    total_loss = predict_loss

            # VLA backward propagation
            self.accelerator.backward(total_loss)

            # For non-DeepSpeed: clip explicitly and capture pre-clip norm before optimizer.step().
            # Only on the micro-step that actually synchronises gradients - clipping partially
            # accumulated gradients would distort the final update.
            # For DeepSpeed: gradient clipping is handled by ds_config internally; calling
            # clip_grad_norm_ here returns the *previous* step's norm (stored in engine._global_grad_norm
            # which is only updated during optimizer.step()), so we skip it here and retrieve
            # the norm after step() below.
            if not is_deepspeed and self.accelerator.sync_gradients:
                gc = getattr(self.config.trainer, "gradient_clipping", None)
                max_norm = float(gc) if gc is not None else float("inf")
                grad_norm_pre_clip = self.accelerator.clip_grad_norm_(
                    self.model.parameters(), max_norm
                )
                if grad_norm_pre_clip is None:
                    grad_norm_pre_clip = self._total_grad_norm_l2_local(self.model.parameters())

            self.optimizer.step()
            if self.accelerator.sync_gradients:
                self.lr_scheduler.step()
            # Reset gradients AFTER the step. Doing it at the top of the block would wipe the
            # gradients accumulated by earlier micro-steps on the synchronising micro-step.
            self.optimizer.zero_grad()

            # For DeepSpeed: gradient clipping is handled internally during optimizer.step(),
            # which also populates engine._global_grad_norm. Calling clip_grad_norm_(inf)
            # is a no-op for DeepSpeed and returns None, so we read _global_grad_norm directly.
            if is_deepspeed:
                gn = getattr(self.model, "_global_grad_norm", None)
                if gn is None:
                    # Older DeepSpeed / different ZeRO stage: try accelerator fallback
                    gn = self.accelerator.clip_grad_norm_(self.model.parameters(), float("inf"))
                grad_norm_pre_clip = gn

        gn_scalar = self._grad_norm_scalar(grad_norm_pre_clip)
        # Feed the running average only with real measurements from optimizer steps.
        if grad_norm_pre_clip is not None and self.accelerator.sync_gradients:
            self._grad_norm_buffer.append(gn_scalar)

        step_metrics = {
            "align_loss": align_loss.item() if align_loss is not None else None,
            "recon_loss": recon_loss.item() if recon_loss is not None else None,
            "predict_loss": predict_loss.item(),
            "aux_loss_decay_weight": aux_loss_decay_weight,
            "grad_norm_pre_clip": gn_scalar,
        }
        return step_metrics

    def _finalize_training(self):
        """training end processing"""
        # Gather before the main-process guard; see `_gather_state_dict`.
        state_dict = self._gather_state_dict()

        # save final model
        if self.accelerator.is_main_process:
            final_checkpoint = os.path.join(self.config.output_dir, "final_model")
            os.makedirs(final_checkpoint, exist_ok=True)
            torch.save(state_dict, os.path.join(final_checkpoint, "pytorch_model.pt"))
            logger.info(f"Training complete. Final model saved at {final_checkpoint}")
        del state_dict

        # close W&B
        if self.accelerator.is_main_process:
            wandb.finish()

        self.accelerator.wait_for_everyone()
def main(cfg) -> None:
    """Run one full VLA training job described by `cfg`.

    Steps: create the output directory and dump the config, build the model, build the
    dataloader, run `VLATrainer.prepare_training()` and `VLATrainer.train()`, then tear
    down the process group. Uses the module-level `accelerator`.
    """
    logger.info("VLA Training :: Warming Up")

    # create output directory and save config
    output_dir = setup_directories(cfg=cfg)
    # build model
    vla = build_framework(cfg)
    # prepare data
    vla_train_dataloader = prepare_data(cfg=cfg, accelerator=accelerator, output_dir=output_dir)

    # create trainer; optimizer and scheduler are built later in `prepare_training()`
    trainer = VLATrainer(
        cfg=cfg,
        model=vla,
        vla_train_dataloader=vla_train_dataloader,
        optimizer=None,
        lr_scheduler=None,
        accelerator=accelerator,
    )

    # execute training preparation
    trainer.prepare_training()
    # execute training
    trainer.train()

    # And... we're done!
    logger.info("... and that's all, folks!")
    # Guarded so the teardown does not raise when no process group exists
    # (mirrors the check in `setup_directories`).
    if dist.is_initialized():
        dist.barrier()
        dist.destroy_process_group()
if __name__ == "__main__":
    # Script entry point: read the YAML config, layer CLI overrides on top, optionally
    # wait for a debugger on rank 0, then start training.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config_yaml",
        type=str,
        default="starVLA/config/training/starvla_cotrain_oxe.yaml",
        help="Path to YAML config",
    )
    known_args, override_args = cli.parse_known_args()

    # Everything argparse did not recognise is treated as a config override and
    # normalised to OmegaConf dotlist form before being merged over the YAML config.
    yaml_cfg = OmegaConf.load(known_args.config_yaml)
    override_cfg = OmegaConf.from_dotlist(normalize_dotlist_args(override_args))
    cfg = OmegaConf.merge(yaml_cfg, override_cfg)

    # Only rank 0 of an initialised process group waits for a debugger.
    wait_for_debugger = cfg.is_debug and dist.is_initialized() and dist.get_rank() == 0
    if wait_for_debugger:
        import debugpy

        debugpy.listen(("0.0.0.0", 10092))
        print("🔍 Rank 0 waiting for debugger attach on port 10092...")
        debugpy.wait_for_client()

    main(cfg)