def auto_get_module_keys(module, max_depth=0, prefix_list=None, current_depth=0, current_prefix=""):
    """
    Get all submodule keys of a module, with optional recursion-depth and prefix filtering.

    :param module: the module to traverse.
    :param max_depth: the maximum recursion depth; the default of 0 lists only top-level children.
    :param prefix_list: only include modules whose full name starts with one of these prefixes; None means no restriction.
    :param current_depth: the current recursion depth (internal use).
    :param current_prefix: the current name prefix (internal use).
    :return: the list of module keys.
    """
    if current_depth > max_depth:
        return []

    module_keys = []
    for name, sub_module in module.named_children():
        full_name = f"{current_prefix}.{name}" if current_prefix else name
        if prefix_list is None or any(full_name.startswith(prefix) for prefix in prefix_list):
            module_keys.append(full_name)
        module_keys.extend(auto_get_module_keys(sub_module, max_depth, prefix_list, current_depth + 1, full_name))
    return module_keys
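

# Usage sketch (assumes `torch` is installed; the toy `net` below is a
# hypothetical illustration, not part of this repo):
#
#     import torch.nn as nn
#     net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
#     auto_get_module_keys(net)                     # ['0', '1', '2']
#     auto_get_module_keys(net, prefix_list=["0"])  # ['0']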
|
|
|
|
def is_module_trainable(module):
    """
    Check whether a module is trainable: if the module itself has parameters,
    all of them must have requires_grad=True; if the module itself has no
    parameters, its trainability depends on its submodules.
    """
    params = list(module.parameters(recurse=False))
    if params:
        return all(p.requires_grad for p in params)
    else:
        # No direct parameters: treat the module as trainable here; callers
        # recurse into its children to determine the effective trainability.
        return True
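

# Usage sketch (assumes `torch` is installed; the frozen `lin` below is a
# hypothetical illustration):
#
#     import torch.nn as nn
#     lin = nn.Linear(2, 2)
#     is_module_trainable(lin)              # True
#     for p in lin.parameters():
#         p.requires_grad = False
#     is_module_trainable(lin)              # False
#     is_module_trainable(nn.Sequential())  # True (no direct parameters)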
|
|
|
|
def auto_get_trainable_modules(module, prefix="", max_depth=None):
    """
    Recursively traverse the module and return the names of all trainable modules.
    If every submodule of a module is trainable, only the parent module's name is
    returned, instead of recursively listing the names of its submodules.

    parameters:
    - module: the module to traverse.
    - prefix: the name prefix of the current module (internal use).
    - max_depth: the maximum recursion depth; None means unlimited recursion.

    return:
    a list of module names.
    """
    children = list(module.named_children())

    # Leaf module or depth limit reached: report this module itself if trainable.
    if (max_depth is not None and max_depth <= 0) or not children:
        return [prefix] if prefix and is_module_trainable(module) else []

    child_keys = []
    all_children_trainable = True
    for name, child in children:
        full_name = f"{prefix}.{name}" if prefix else name
        keys = auto_get_trainable_modules(child, full_name, None if max_depth is None else max_depth - 1)
        if not keys:
            # The recursion found nothing below this child; fall back to
            # checking the child's own (direct) parameters.
            if is_module_trainable(child):
                keys = [full_name]
            else:
                all_children_trainable = False
        else:
            # The child is fully trainable only if it collapsed to its own
            # name; any other result means only part of it is trainable.
            if keys != [full_name]:
                all_children_trainable = False
        child_keys.extend(keys)

    # Collapse to the parent's name when the parent's own parameters and all
    # of its children are trainable.
    if is_module_trainable(module) and all_children_trainable and child_keys:
        return [prefix] if prefix else child_keys
    else:
        return child_keys
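

# Usage sketch (assumes `torch` is installed; `Toy` is a hypothetical model):
#
#     import torch.nn as nn
#     class Toy(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.encoder = nn.Linear(4, 4)
#             self.head = nn.Linear(4, 2)
#     model = Toy()
#     for p in model.encoder.parameters():
#         p.requires_grad = False
#     auto_get_trainable_modules(model)  # ['head'] (the frozen encoder is omitted)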
|
|
|
|
def print_freeze_status(self):
    """
    For each top-level submodule: if all of its parameters are in the same state
    (all frozen or all trainable), print only the top-level module; if a top-level
    submodule has mixed parameter states (some frozen, some trainable), list the
    state of each parameter under it. Intended to be used as a method of an
    nn.Module (`self` is the module).
    """
    from collections import defaultdict

    # Group parameters by their top-level submodule and count states.
    status_dict = defaultdict(lambda: {"Frozen": 0, "Trainable": 0, "params": []})
    for full_name, param in self.named_parameters():
        # The first path component names the top-level submodule.
        top_module = full_name.split(".", 1)[0]
        state = "Trainable" if param.requires_grad else "Frozen"
        status_dict[top_module]["params"].append((full_name, state))
        status_dict[top_module][state] += 1

    print("=== module parameter freezing status ===")
    for top_module, info in status_dict.items():
        frozen_count = info["Frozen"]
        trainable_count = info["Trainable"]

        if frozen_count > 0 and trainable_count == 0:
            # Uniformly frozen: print the summary line only.
            print(f"{top_module:40s} | all Frozen ({frozen_count} parameters)")
        elif trainable_count > 0 and frozen_count == 0:
            # Uniformly trainable: print the summary line only.
            print(f"{top_module:40s} | all Trainable ({trainable_count} parameters)")
        else:
            # Mixed state: print the summary plus every parameter under it.
            print(f"{top_module:40s} | mixed state → Frozen: {frozen_count}, Trainable: {trainable_count}")
            for pname, pstate in info["params"]:
                print(f"    {pname:60s} | {pstate}")
    print("=========================\n")
|
|
|
|
|
|
import warnings


class Registry:
    def __init__(self, name: str):
        self.name = name
        self._registry = {}

    def register(self, key: str):
        """Decorator: register a builder function or class under `key`."""
        def decorator(framework_class):
            if key in self._registry:
                # Re-registration overwrites the previous entry; warn so that
                # silent key collisions are visible.
                warnings.warn(f"Registry `{self.name}`: key `{key}` is already registered; overwriting.")
            self._registry[key] = framework_class
            return framework_class
        return decorator

    def __getitem__(self, key):
        return self._registry[key]

    def list(self):
        """Return a shallow copy of the registry as a mapping {key: registered object}."""
        return {k: v for k, v in self._registry.items()}
|
|
FRAMEWORK_REGISTRY = Registry("frameworks")
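

# Usage sketch (the `MyFramework` class and the "my_framework" key below are
# hypothetical illustrations):
#
#     @FRAMEWORK_REGISTRY.register("my_framework")
#     class MyFramework:
#         pass
#
#     FRAMEWORK_REGISTRY["my_framework"]  # -> MyFramework
#     FRAMEWORK_REGISTRY.list()           # -> {'my_framework': MyFramework}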
|
|
|
|
|
|
import os
import json
from pathlib import Path

from omegaconf import OmegaConf

from starVLA.training.trainer_utils import initialize_overwatch

overwatch = initialize_overwatch(__name__)
|
|

def read_mode_config(pretrained_checkpoint):
    """
    Same as read_model_config (legacy duplicate kept for backward compatibility).

    Args:
        pretrained_checkpoint: Path to a `.pt` or `.safetensors` checkpoint file.

    Returns:
        tuple:
            global_cfg (dict): the resolved training config loaded from `config.yaml`.
            norm_stats (dict): dataset statistics loaded from `dataset_statistics.json`.
    """
    if os.path.isfile(pretrained_checkpoint):
        checkpoint_pt = Path(pretrained_checkpoint)
        overwatch.info(f"Loading from local checkpoint path `{checkpoint_pt}`")

        assert checkpoint_pt.suffix in (".pt", ".safetensors"), \
            f"Unsupported checkpoint suffix `{checkpoint_pt.suffix}`, expected `.pt` or `.safetensors`"
        # The run directory is two levels above the checkpoint file.
        run_dir = checkpoint_pt.parents[1]

        config_yaml, dataset_statistics_json = run_dir / "config.yaml", run_dir / "dataset_statistics.json"
        assert config_yaml.exists(), f"Missing `config.yaml` for `{run_dir}`"
        assert dataset_statistics_json.exists(), f"Missing `dataset_statistics.json` for `{run_dir}`"

        try:
            ocfg = OmegaConf.load(str(config_yaml))
            global_cfg = OmegaConf.to_container(ocfg, resolve=True)
        except Exception as e:
            overwatch.error(f"❌ Failed to load YAML config `{config_yaml}`: {e}")
            raise

        with open(dataset_statistics_json, "r") as f:
            norm_stats = json.load(f)
    else:
        overwatch.error(f"❌ Pretrained checkpoint `{pretrained_checkpoint}` does not exist.")
        raise FileNotFoundError(f"Pretrained checkpoint `{pretrained_checkpoint}` does not exist.")
    return global_cfg, norm_stats
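

# Usage sketch (the run-directory layout below mirrors what the asserts above
# expect; the concrete file and directory names are hypothetical):
#
#     <run_dir>/
#         config.yaml
#         dataset_statistics.json
#         checkpoints/
#             step-1000.pt
#
#     global_cfg, norm_stats = read_mode_config("<run_dir>/checkpoints/step-1000.pt")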
|
|
|
|