# Source: cross13tasks/code/training/trainer_utils/trainer_tools.py
# (uploaded by Timsty via huggingface_hub, commit e94400c)
"""
metrics.py
Utility classes defining a Metrics container and multiple Trackers to enable model/stage-specific logging to various
endpoints (e.g., JSONL local logs, Weights & Biases).
"""
from typing import Tuple
import re
import json
import numpy as np
import torch
from accelerate.logging import get_logger
logger = get_logger(__name__)
# === CLI argument helpers ===
#
# utils/cli_parser.py
def normalize_dotlist_args(args):
    """
    Turn ``--key value`` style CLI tokens into dotlist ``key=value`` strings.

    Examples:
        ['--x.y', 'val']  -> ['x.y=val']
        ['--x.y=val']     -> ['x.y=val']
        ['--flag']        -> ['flag=true']   (no value follows, or next token is another ``--`` key)

    Tokens that do not start with ``--`` and are not consumed as a value are dropped.
    """
    out = []
    idx = 0
    total = len(args)
    while idx < total:
        token = args[idx]
        idx += 1
        if not token.startswith("--"):
            continue  # orphaned value with no preceding key
        name = token.lstrip("-")
        if "=" in name:
            out.append(name)
            continue
        has_value = idx < total and not args[idx].startswith("--")
        if has_value:
            out.append(f"{name}={args[idx]}")
            idx += 1  # the value token has been consumed
        else:
            out.append(f"{name}=true")
    return out
def build_param_lr_groups(model, cfg):
    """
    Build optimizer parameter groups with per-module learning rates.

    ``cfg.trainer.learning_rate`` is a mapping. Its ``"base"`` entry (default ``1e-4``)
    applies to every trainable parameter not claimed by another entry. Every other key
    is a dotted attribute path under ``model`` (e.g. ``"action_model.net"``) mapped to
    that module's learning rate. Parameters under the comma-separated paths in
    ``cfg.trainer.freeze_modules`` are excluded from all groups.

    Args:
        model: nn.Module model object.
        cfg: config object; requires the ``cfg.trainer.learning_rate`` mapping and
            optionally ``cfg.trainer.freeze_modules`` (comma-separated string; any
            non-string value is treated as "freeze nothing").

    Returns:
        List[Dict]: param groups (keys ``params``, ``lr``, ``name``) that can be passed to
        a ``torch.optim`` optimizer. Each parameter appears in at most one group. When
        configured module paths overlap, the entry listed first in ``learning_rate``
        claims the shared parameters.
    """
    lr_cfg = cfg.trainer.learning_rate
    base_lr = lr_cfg.get("base", 1e-4)  # default base learning rate

    freeze_modules = cfg.trainer.get("freeze_modules", "")
    if not isinstance(freeze_modules, str):
        freeze_modules = ""
    freeze_patterns = [p.strip() for p in freeze_modules.split(",") if p.strip()]

    def _resolve(path):
        """Walk a dotted attribute path from ``model``; raises AttributeError if missing."""
        module = model
        for attr in path.split("."):
            module = getattr(module, attr)
        return module

    frozen_params = set()
    for freeze_path in freeze_patterns:
        try:
            frozen_params.update(id(p) for p in _resolve(freeze_path).parameters())
        except AttributeError:
            print(f"⚠️ freeze module path does not exist: {freeze_path}")

    used_params = set()
    param_groups = []
    for module_name, lr in lr_cfg.items():
        if module_name == "base":
            continue
        try:
            module_params = list(_resolve(module_name).parameters())
        except AttributeError:
            # Warn instead of silently ignoring, so typos in module paths are visible.
            print(f"⚠️ module path `{module_name}` not found in vla")
            continue
        # Skip frozen params and params already claimed by an earlier entry:
        # torch.optim rejects a parameter that appears in more than one group.
        params = [p for p in module_params if id(p) not in frozen_params and id(p) not in used_params]
        if params:  # only add a group if it has trainable parameters
            param_groups.append({"params": params, "lr": lr, "name": module_name})
            used_params.update(id(p) for p in params)

    # Everything not claimed above (and not frozen) trains at the base learning rate.
    other_params = [p for p in model.parameters() if id(p) not in used_params and id(p) not in frozen_params]
    if other_params:
        param_groups.append({"params": other_params, "lr": base_lr, "name": "base"})
    return param_groups
import torch.distributed as dist
def _is_main_process_dist() -> bool:
return (not dist.is_initialized()) or dist.get_rank() == 0
def only_main_process(func):
    """
    Decorator: run ``func`` only on the main process (rank 0).

    On non-zero ranks of an initialized torch.distributed job, the call is skipped
    and ``None`` is returned. When torch.distributed is unavailable or not
    initialized, ``func`` always runs.
    """
    from functools import wraps

    @wraps(func)  # preserve __name__ / __doc__ of the decorated function
    def wrapper(*args, **kwargs):
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            return None  # non-main process does not execute
        return func(*args, **kwargs)

    return wrapper
from torchvision.ops import box_iou
from PIL import Image
def resize_images(images, target_size=(224, 224)):
    """
    Resize every PIL image inside an arbitrarily nested list.

    :param images: a single ``PIL.Image.Image`` or a (possibly nested) list of them.
    :param target_size: ``(width, height)`` forwarded to ``Image.resize``.
    :return: the resized image, or a new list mirroring the input nesting.
    :raises ValueError: if any element is neither a PIL image nor a list.
    """
    if isinstance(images, list):
        # Recurse so nesting depth and order are preserved.
        return [resize_images(item, target_size) for item in images]
    if isinstance(images, Image.Image):
        return images.resize(target_size)
    raise ValueError("Unsupported image type or structure.")
class TrainerUtils:
    """
    Collection of training helpers.

    Most members are ``@staticmethod``s. ``save_full_checkpoint``,
    ``resume_from_full_checkpoint`` and ``_get_latest_checkpoint`` are plain methods
    that read ``self.accelerator`` (and ``self.model`` / ``self.config``), so they are
    presumably used from a trainer class that inherits from this one.
    NOTE(review): confirm against the trainer implementation.
    """

    @staticmethod
    def freeze_backbones(model, freeze_modules=""):
        """
        Freeze the submodules named by a comma-separated list of relative paths.

        Each entry is a dotted attribute path resolved from ``model`` (no recursive name
        matching). For example, ``"qwen_vl_interface, action_model.net"`` freezes
        ``model.qwen_vl_interface`` and ``model.action_model.net``. The string is
        typically read from ``config.trainer.freeze_modules``.

        Args:
            model: nn.Module model object.
            freeze_modules: comma-separated relative module paths. Empty or non-string
                values freeze nothing.

        Returns:
            model: the same nn.Module, with ``requires_grad=False`` set on every
            parameter of each resolved submodule.
        """
        frozen = []
        if _is_main_process_dist():
            print("#" * 30)
            print(freeze_modules)
        if freeze_modules and isinstance(freeze_modules, str):
            patterns = [p.strip() for p in freeze_modules.split(",") if p.strip()]
            for path in patterns:
                module = model
                try:
                    # "action_model.net" -> model.action_model.net
                    for attr in path.split("."):
                        module = getattr(module, attr)
                    # freeze the module and all of its submodules' parameters
                    for param in module.parameters():
                        param.requires_grad = False
                    frozen.append(path)
                except AttributeError:
                    # missing attribute (or a non-Module attribute): warn and skip
                    print(f"⚠️ module path does not exist, cannot freeze: {path}")
                    continue
        if _is_main_process_dist():
            print(f"🔒 Frozen modules with re pattern: {frozen}")
        return model

    @staticmethod
    def print_trainable_parameters(model):
        """
        Print total and trainable parameter counts (in millions) on the main process.

        :param model: PyTorch model instance
        :return: ``(num_params, num_trainable_params)`` on every rank; only the main
            process prints.
        """
        num_params = sum(p.numel() for p in model.parameters())
        num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        if _is_main_process_dist():
            print("📊 model parameter statistics:")
            print(
                f"# Parameters (in millions): {num_params / 10**6:.3f} Total, "
                f"{num_trainable_params / 10**6:.3f} Trainable"
            )
        return num_params, num_trainable_params

    @staticmethod
    def load_pretrained_backbones(model, checkpoint_path=None, reload_modules=None):
        """
        Load weights from ``checkpoint_path`` into ``model``.

        - If ``reload_modules`` is set (comma-separated dotted paths), only those
          submodules are loaded. Each is loaded from the checkpoint keys prefixed by
          ``"<path>."``, with ``strict=True``.
        - Otherwise the whole state dict is loaded with ``strict=False``. The counts of
          missing and unexpected keys are printed on the main process.

        ``.safetensors`` files are read with ``safetensors.torch.load_file``; anything
        else is read with ``torch.load(..., map_location="cpu")``.

        Returns:
            model: the same model object. It is also returned (unchanged) when
            ``checkpoint_path`` is empty, so ``model = load_pretrained_backbones(model, ...)``
            is always safe.

        Raises:
            RuntimeError: if the checkpoint cannot be read or the full-model load fails.
        """
        if not checkpoint_path:
            # Return the model (not a list) so callers that reassign the result keep it.
            return model
        if _is_main_process_dist():
            print(f"📦 loading checkpoint: {checkpoint_path}")
        try:
            if _is_safetensors_path(checkpoint_path):
                from safetensors.torch import load_file

                checkpoint = load_file(checkpoint_path)
            else:
                checkpoint = torch.load(checkpoint_path, map_location="cpu")
        except Exception as e:
            raise RuntimeError(f"❌ loading checkpoint failed: {e}") from e

        if reload_modules:  # partial load
            module_paths = [p.strip() for p in reload_modules.split(",") if p.strip()]
            for path in module_paths:
                module = model
                try:
                    for attr in path.split("."):  # walk down to the target submodule
                        module = getattr(module, attr)
                    prefix = path + "."
                    sub_state_dict = {k[len(prefix):]: v for k, v in checkpoint.items() if k.startswith(prefix)}
                    if sub_state_dict:
                        module.load_state_dict(sub_state_dict, strict=True)
                        if _is_main_process_dist():
                            print(f"✅ parameters loaded to module '{path}'")
                    else:
                        print(f"⚠️ parameters not found in checkpoint '{path}'")
                except AttributeError:
                    print(f"❌ cannot find module path: {path}")
        else:  # full load
            try:
                result = model.load_state_dict(checkpoint, strict=False)
            except Exception as e:
                raise RuntimeError(f"❌ loading full model failed: {e}") from e
            if _is_main_process_dist():
                print("✅ loaded <full_model> model parameters")
                # strict=False silently skips mismatched keys; report how many there were.
                missing = getattr(result, "missing_keys", None)
                unexpected = getattr(result, "unexpected_keys", None)
                if missing is not None and unexpected is not None:
                    print(f"   missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
        return model

    @staticmethod
    def print_freeze_status(model):
        """
        Print one line per named parameter showing whether it is Frozen or Trainable.

        :param model: PyTorch model instance
        """
        for name, param in model.named_parameters():
            status = "Frozen" if not param.requires_grad else "Trainable"
            print(f"{name:60s} | {status}")

    @staticmethod
    def setup_distributed_training(accelerator, *components):
        """
        Wrap components with ``accelerator.prepare``.

        :param accelerator: Accelerate instance
        :param components: any number of components (model, optimizer, dataloader, ...)
        :return: whatever ``accelerator.prepare(*components)`` returns (same order as input)
        """
        return accelerator.prepare(*components)

    def save_full_checkpoint(self, completed_steps, checkpoint_dir, output_dir):
        """Save full training state (prepared components + RNG) for resume,
        plus a standalone model weights file for deployment.

        The standalone file format is controlled by ``self.config.trainer.save_format``
        (``"pt"`` or ``"safetensors"``). It defaults to ``"pt"`` when unset.
        Must be called after accelerator.prepare(), on every process.

        Args:
            completed_steps: Current training step count.
            checkpoint_dir: Directory to save checkpoints (e.g. results/<run_id>/checkpoints).
            output_dir: Top-level run directory for summary.jsonl and config.
        """
        from pathlib import Path

        save_format = getattr(self.config.trainer, "save_format", "pt")
        use_safe = save_format == "safetensors"

        # Save full accelerator state for all prepared components.
        state_dir = os.path.join(checkpoint_dir, f"steps_{completed_steps}")
        self.accelerator.save_state(state_dir, safe_serialization=use_safe)

        # get_state_dict must run on every rank: under DeepSpeed ZeRO-3 / FSDP it gathers
        # sharded weights collectively, so calling it only on rank 0 would hang.
        # Only the main process writes the result.
        state_dict = self.accelerator.get_state_dict(self.model)

        if self.accelerator.is_main_process:
            # Standalone model weights for deployment
            if state_dict is not None:
                if use_safe:
                    from safetensors.torch import save_file

                    weights_path = os.path.join(checkpoint_dir, f"steps_{completed_steps}_model.safetensors")
                    save_file(state_dict, weights_path)
                else:
                    weights_path = os.path.join(checkpoint_dir, f"steps_{completed_steps}_pytorch_model.pt")
                    torch.save(state_dict, weights_path)

            # Append to summary log
            summary_data = {"steps": completed_steps}
            with open(os.path.join(output_dir, "summary.jsonl"), "a") as f:
                f.write(json.dumps(summary_data) + "\n")
            self.accelerator.print(f"✅ Checkpoint saved at {state_dir}")

            # Save accessed config if available
            from starVLA.training.trainer_utils.config_tracker import AccessTrackedConfig

            if isinstance(self.config, AccessTrackedConfig):
                self.config.save_accessed_config(
                    Path(output_dir) / "config.yaml",
                    use_original_values=False,
                )
        self.accelerator.wait_for_everyone()

    def resume_from_full_checkpoint(self, checkpoint_dir):
        """Load full training state from an accelerator state directory.

        Must be called **after** accelerator.prepare() (DeepSpeed requirement).

        Args:
            checkpoint_dir: Path to a steps_N/ directory containing full state
                (a trailing path separator is accepted).

        Returns:
            int: The completed_steps parsed from the directory name (steps_N), or 0.
        """
        self.accelerator.load_state(checkpoint_dir)
        self.accelerator.print(f"Resumed full training state from: {checkpoint_dir}")
        return self._parse_completed_steps(checkpoint_dir)

    @staticmethod
    def _parse_completed_steps(checkpoint_dir):
        """Return N from a ``.../steps_N`` path (trailing separators allowed), else 0."""
        # normpath strips a trailing "/" so basename() does not return "".
        dir_name = os.path.basename(os.path.normpath(checkpoint_dir))
        match = re.match(r"^steps_(\d+)$", dir_name)
        return int(match.group(1)) if match else 0

    @staticmethod
    def euclidean_distance(predicted: np.ndarray, ground_truth: np.ndarray) -> float:
        """Return the L2 norm of ``predicted - ground_truth`` (flattened over all elements)."""
        return np.linalg.norm(predicted - ground_truth)

    @staticmethod
    def _reset_dataloader(dataloader, epoch_counter):
        """
        Start a new epoch: bump the counter, call ``sampler.set_epoch`` when the sampler
        supports it (needed for distributed shuffling), and return a fresh iterator.

        :return: ``(new_iterator, incremented_epoch_counter)``
        """
        epoch_counter += 1
        if hasattr(dataloader, "sampler") and callable(getattr(dataloader.sampler, "set_epoch", None)):
            dataloader.sampler.set_epoch(epoch_counter)
        return iter(dataloader), epoch_counter

    @staticmethod
    def compute_grad_angle_with_stats(
        grads_a: list[torch.Tensor],
        grads_v: list[torch.Tensor],
        max_rows: int = 32,
        max_cols: int = 7,
    ) -> Tuple[float, float]:
        """
        Measure the angle (degrees) between two gradient sets, row by row.

        Only the first tensor of each list is used, and only its top-left
        ``[:max_rows, :max_cols]`` slice (defaults 32 x 7, the previously hard-coded
        sizes). Short rows are used to avoid degenerate cosine similarity in
        high-dimensional space. Each row of the slice yields one angle.

        Args:
            grads_a, grads_v: gradient tensor lists for the same parameters. The first
                entry of each must be at least 2-D, with matching shapes.
            max_rows, max_cols: size of the compared slice.

        Returns:
            (mean_angle_deg, angle_std): the mean of the per-row angles, and their
            (unbiased) standard deviation in degrees. Note that this is sqrt(var), not
            the variance. The spread is 0.0 when fewer than two rows are compared.
        """
        # float32 keeps acos near +/-1 usable when gradients are low precision (e.g. bf16).
        grads_action = grads_a[0][:max_rows, :max_cols].detach().float()
        grads_vl = grads_v[0][:max_rows, :max_cols].detach().float()

        dot = (grads_action * grads_vl).sum(dim=-1)
        # eps under the sqrt avoids division by zero for all-zero rows
        norm_a = torch.sqrt((grads_action * grads_action).sum(dim=-1) + 1e-16)
        norm_v = torch.sqrt((grads_vl * grads_vl).sum(dim=-1) + 1e-16)
        cos_sim = (dot / (norm_a * norm_v)).clamp(-1.0, 1.0)
        angle_degs = (torch.acos(cos_sim) * (180.0 / torch.pi)).cpu()

        mean_angle_deg = angle_degs.mean().item()
        # the unbiased std of a single sample is NaN; report 0 spread instead
        angle_std = angle_degs.std().item() if angle_degs.numel() > 1 else 0.0
        return mean_angle_deg, angle_std

    @staticmethod
    def pcgrad_project(grads_a: list[torch.Tensor], grads_v: list[torch.Tensor]) -> list[torch.Tensor]:
        """
        Apply a PCGrad-style projection to ``grads_v`` to suppress negative transfer.

        If the total dot product of the two gradient lists is negative:
            grads_v <- grads_v - (dot / (||grads_a||^2 + 1e-6)) * grads_a
        Otherwise ``grads_v`` is returned unchanged.
        """
        dot, norm_a_sq = 0.0, 0.0
        for g_a, g_v in zip(grads_a, grads_v):
            dot += torch.sum(g_a * g_v)
            norm_a_sq += torch.sum(g_a * g_a)
        if dot < 0:
            coeff = dot / (norm_a_sq + 1e-6)
            grads_v = [g_v - coeff * g_a for g_a, g_v in zip(grads_a, grads_v)]
        return grads_v

    @staticmethod
    def l1_distance(predicted: np.ndarray, ground_truth: np.ndarray) -> float:
        """Sum of absolute differences (L1 distance); not averaged over elements."""
        return np.sum(np.abs(predicted - ground_truth))

    @staticmethod
    def _bbox_iou(pred_dict, gt_dict, key):
        """IoU of ``pred_dict[key]["bbox_2d"]`` vs the ground truth; 0.0 if the prediction is unusable."""
        try:
            pred_box = torch.tensor(pred_dict[key]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
        except (TypeError, KeyError, ValueError, RuntimeError):
            # unparseable prediction (None), a missing field, or a non-numeric bbox
            return 0.0
        if pred_box.shape != (1, 4):
            return 0.0
        gt_box = torch.tensor(gt_dict[key]["bbox_2d"], dtype=torch.float32).unsqueeze(0)
        return box_iou(pred_box, gt_box).item()

    @staticmethod
    def eval_qwenpi(qwenpi, dataloader, num_batches=20):
        """
        Evaluate a QwenQFormerDiT model: box IoU of the predicted pick/place boxes, and
        the per-element action distance.

        Args:
            qwenpi: QwenQFormerDiT model instance (must provide ``predict_action_withCoT``).
            dataloader: yields batches that are lists of dicts with keys
                ``image``, ``lang``, ``action``, ``solution``.
            num_batches: maximum number of batches to evaluate.

        Returns:
            dict: ``iou_scores`` (one ``{"pick_iou", "place_iou"}`` per sample; 0.0 when the
            predicted JSON could not be parsed) and ``average_action_distance`` (mean over
            batches of L2 distance / element count; NaN if no batch was evaluated).
        """
        iou_scores = []
        action_distances = []
        dataset_iter = iter(dataloader)
        for _ in range(num_batches):
            try:
                batch_samples = next(dataset_iter)
            except StopIteration:
                break

            images = [example["image"] for example in batch_samples]
            instructions = [example["lang"] for example in batch_samples]
            actions = [example["action"] for example in batch_samples]
            solutions = [example["solution"] for example in batch_samples]

            predicted_solutions, normalized_actions = qwenpi.predict_action_withCoT(
                images=images, instructions=instructions, use_ddim=False, num_ddim_steps=20
            )
            parsed_solutions = [TrainerUtils.extract_json_from_string(s) for s in predicted_solutions]

            for pred_dict, gt_dict in zip(parsed_solutions, solutions):
                iou_scores.append(
                    {
                        "pick_iou": TrainerUtils._bbox_iou(pred_dict, gt_dict, "pick"),
                        "place_iou": TrainerUtils._bbox_iou(pred_dict, gt_dict, "place"),
                    }
                )

            actions = np.array(actions)
            num_points = np.prod(actions.shape)  # B * len * dim
            action_distance = TrainerUtils.euclidean_distance(normalized_actions, actions)
            action_distances.append(action_distance / num_points)

        avg_action_distance = np.mean(action_distances) if action_distances else float("nan")
        return {"iou_scores": iou_scores, "average_action_distance": avg_action_distance}

    @staticmethod
    def extract_json_from_string(input_string):
        """
        Extract the JSON object embedded in a string and parse it.

        The span from the first ``{`` to the last ``}`` (greedy, across newlines) is
        passed to ``json.loads``.

        Args:
            input_string (str): string that may contain extra text around the JSON.

        Returns:
            The parsed object, or None (with a printed message) if no ``{...}`` span is
            found or it fails to decode.
        """
        json_match = re.search(r"{.*}", input_string, re.DOTALL)
        if json_match:
            json_str = json_match.group(0)
            try:
                return json.loads(json_str)
            except json.JSONDecodeError as e:
                print(f"JSON decode failed: {e}")
                return None
        else:
            print("No valid JSON part found")
            return None

    def _get_latest_checkpoint(self, checkpoint_dir):
        """Find the latest checkpoint in ``checkpoint_dir`` by step number.

        Recognizes ``steps_N/`` directories (full accelerator state for resume) and
        weight-only files ``steps_N_pytorch_model.pt`` / ``steps_N_model.safetensors``.
        A directory wins over a file at the same step.

        Returns:
            (path, completed_steps), or (None, 0) if the directory is missing or holds
            no recognizable checkpoint.
        """
        if not os.path.exists(checkpoint_dir):
            self.accelerator.print(f"No checkpoint directory found at {checkpoint_dir}")
            return None, 0

        checkpoints_with_steps = []
        for entry in os.listdir(checkpoint_dir):
            full_path = os.path.join(checkpoint_dir, entry)

            # steps_N/ directories (contents are not validated here)
            dir_match = re.match(r"^steps_(\d+)$", entry)
            if dir_match and os.path.isdir(full_path):
                checkpoints_with_steps.append((full_path, int(dir_match.group(1)), "dir"))
                continue

            # Weight-only files: steps_N_pytorch_model.pt or steps_N_model.safetensors
            file_match = re.match(r"^steps_(\d+)_(?:pytorch_model\.pt|model\.safetensors)$", entry)
            if file_match and os.path.isfile(full_path):
                checkpoints_with_steps.append((full_path, int(file_match.group(1)), "file"))

        if not checkpoints_with_steps:
            self.accelerator.print(f"No checkpoints found in {checkpoint_dir}")
            return None, 0

        # Sort by step number, then by type priority (dir > file) so a directory wins ties.
        type_priority = {"file": 0, "dir": 1}
        checkpoints_with_steps.sort(key=lambda x: (x[1], type_priority[x[2]]))
        latest_path, completed_steps, fmt = checkpoints_with_steps[-1]
        self.accelerator.print(f"Latest checkpoint found: {latest_path} (format={fmt})")
        return latest_path, completed_steps
import os
def is_main_process():
    """Return True when the ``RANK`` environment variable is unset or equal to 0."""
    return int(os.environ.get("RANK", 0)) == 0
def _is_safetensors_path(path):
"""Check if a path refers to a safetensors file."""
return str(path).endswith(".safetensors")