| |
|
|
| import dataclasses |
| import json |
| import os |
| from pathlib import Path |
|
|
| import safetensors.torch |
| import torch |
| import torch.nn.functional as F |
| import tyro |
|
|
| import openpi.models.pi0_config |
| import openpi.models_pytorch.pi0_pytorch |
| import openpi.training.config as _config |
|
|
|
|
@dataclasses.dataclass
class Args:
    """Command-line arguments (parsed with ``tyro``) for converting a single-head checkpoint."""

    # Directory holding the single-head checkpoint; ``model.safetensors`` is read from it.
    single_ckpt: str
    # Name of the training config, looked up with ``_config.get_config``.
    config_name: str
    # Directory that receives ``model.safetensors``, ``config.json`` and ``init_parallel_metadata.json``.
    output_path: str
|
|
|
|
def _build_model_config(config: _config.TrainConfig) -> openpi.models.pi0_config.Pi0Config:
    """Return the Pi0Config to instantiate, with ``dtype`` set to the PyTorch training precision.

    If ``config.model`` is not a ``Pi0Config``, a new ``Pi0Config`` is built from its
    ``action_dim``, ``action_horizon`` and ``max_token_len``, plus optional attributes that fall
    back to ``"gemma_2b"`` / ``"gemma_300m"`` / ``False`` / ``None`` / ``None`` when absent.
    Otherwise a shallow copy of ``config.model`` is returned with ``dtype`` overridden.

    ``config`` (including ``config.model``) is not modified.
    """
    import copy  # local import: only needed for the copy-then-override path below

    precision = config.pytorch_training_precision
    source = config.model
    if not isinstance(source, openpi.models.pi0_config.Pi0Config):
        return openpi.models.pi0_config.Pi0Config(
            dtype=precision,
            action_dim=source.action_dim,
            action_horizon=source.action_horizon,
            max_token_len=source.max_token_len,
            paligemma_variant=getattr(source, "paligemma_variant", "gemma_2b"),
            action_expert_variant=getattr(source, "action_expert_variant", "gemma_300m"),
            pi05=getattr(source, "pi05", False),
            arm_action_dims=getattr(source, "arm_action_dims", None),
            action_expert_mode=getattr(source, "action_expert_mode", None),
        )

    # Pi0Config is presumably a frozen dataclass, hence object.__setattr__. Apply the override to
    # a shallow copy so the caller's config.model is not mutated in place.
    model_cfg = copy.copy(source)
    object.__setattr__(model_cfg, "dtype", precision)
    return model_cfg
|
|
|
|
| def _copy_factorized_heads(model, weight_in, bias_in, weight_out, bias_out) -> None: |
| hidden_width = weight_in.shape[0] |
| with torch.no_grad(): |
| model.action_in_proj_arms[0].weight.copy_(weight_in[:, 0:16]) |
| model.action_in_proj_arms[0].bias.zero_() |
| model.action_in_proj_arms[1].weight.copy_(weight_in[:, 16:32]) |
| model.action_in_proj_arms[1].bias.zero_() |
|
|
| if hasattr(model, "arm_token_fuse"): |
| fuse_weight = torch.zeros_like(model.arm_token_fuse.weight) |
| identity = torch.eye(hidden_width, dtype=fuse_weight.dtype) |
| fuse_weight[:, 0:hidden_width] = identity |
| fuse_weight[:, hidden_width : 2 * hidden_width] = identity |
| model.arm_token_fuse.weight.copy_(fuse_weight) |
| model.arm_token_fuse.bias.copy_(bias_in) |
|
|
| model.action_out_proj_arms[0].weight.copy_(weight_out[0:16, :]) |
| model.action_out_proj_arms[0].bias.copy_(bias_out[0:16]) |
| model.action_out_proj_arms[1].weight.copy_(weight_out[16:32, :]) |
| model.action_out_proj_arms[1].bias.copy_(bias_out[16:32]) |
|
|
|
|
| def _copy_split_expert_weights(model, single_state) -> None: |
| model_state = model.state_dict() |
| with torch.no_grad(): |
| for key, value in single_state.items(): |
| if not key.startswith("paligemma_with_expert.gemma_expert."): |
| continue |
| suffix = key.removeprefix("paligemma_with_expert.gemma_expert.") |
| left_key = f"paligemma_with_expert.left_gemma_expert.{suffix}" |
| right_key = f"paligemma_with_expert.right_gemma_expert.{suffix}" |
| model_state[left_key].copy_(value.to(dtype=model_state[left_key].dtype)) |
| model_state[right_key].copy_(value.to(dtype=model_state[right_key].dtype)) |
|
|
|
|
| def _expert_copy_max_abs_diff(model, single_state, target_prefix: str) -> float: |
| model_state = model.state_dict() |
| max_abs_diff = 0.0 |
| for key, value in single_state.items(): |
| if not key.startswith("paligemma_with_expert.gemma_expert."): |
| continue |
| suffix = key.removeprefix("paligemma_with_expert.gemma_expert.") |
| target_key = f"{target_prefix}{suffix}" |
| diff = (model_state[target_key].to(torch.float32) - value.to(torch.float32)).abs().max().item() |
| max_abs_diff = max(max_abs_diff, float(diff)) |
| return max_abs_diff |
|
|
|
|
def _max_abs_diff(expected: torch.Tensor, actual: torch.Tensor) -> float:
    """Return the largest element-wise ``|expected - actual|`` as a Python float."""
    return float((expected - actual).abs().max().item())


def main() -> None:
    """Warm-start a parallel/split action-head PI0 model from a single-head checkpoint.

    Steps:
      1. Parse ``Args`` with tyro and build the model config for ``args.config_name``. The
         config must use parallel action heads with ``arm_action_dims == (16, 16)``.
      2. Load ``<single_ckpt>/model.safetensors`` on CPU and check that its ``action_in_proj``
         and ``action_out_proj`` operate on packed 32-dim actions.
      3. Build ``PI0Pytorch``, load every matching tensor non-strictly, and split the single
         action projections into per-arm heads. For split-expert configs, also copy the single
         action expert into both arm experts.
      4. Compare the new heads/experts against the single-head weights on fixed-seed random
         probes and record the max absolute differences in the metadata.
      5. Write ``model.safetensors``, ``config.json`` and ``init_parallel_metadata.json`` to
         ``args.output_path`` and print a summary.

    Raises:
        ValueError: if the config does not use parallel heads with (16, 16) arm dims, or if the
            checkpoint's action projections are not 32-dim.
    """
    args = tyro.cli(Args)
    config = _config.get_config(args.config_name)
    model_cfg = _build_model_config(config)
    if not model_cfg.use_parallel_action_heads:
        raise ValueError(f"Config {args.config_name} does not use factorized or split action heads.")
    # Guard against None so a missing value yields this ValueError, not a TypeError from tuple(None).
    arm_action_dims = model_cfg.arm_action_dims
    if arm_action_dims is None or tuple(arm_action_dims) != (16, 16):
        raise ValueError(f"Expected arm_action_dims=(16, 16), got {model_cfg.arm_action_dims}.")

    # Load and validate the source checkpoint before building the target model, so a wrong
    # checkpoint fails fast.
    single_state = safetensors.torch.load_file(os.path.join(args.single_ckpt, "model.safetensors"), device="cpu")
    weight_in = single_state["action_in_proj.weight"]
    bias_in = single_state["action_in_proj.bias"]
    weight_out = single_state["action_out_proj.weight"]
    bias_out = single_state["action_out_proj.bias"]
    if weight_in.shape[1] != 32 or weight_out.shape[0] != 32:
        raise ValueError(
            f"Expected single-head checkpoint with packed 32-dim actions, got in={tuple(weight_in.shape)} out={tuple(weight_out.shape)}."
        )
    hidden_width = weight_in.shape[0]

    parallel_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_cfg)
    # Non-strict load: the parallel model's key set differs from the single-head checkpoint.
    # The missing/unexpected keys are recorded in the metadata below.
    missing, unexpected = parallel_model.load_state_dict(single_state, strict=False)

    _copy_factorized_heads(parallel_model, weight_in, bias_in, weight_out, bias_out)
    if model_cfg.use_split_action_expert:
        _copy_split_expert_weights(parallel_model, single_state)

    # Fixed-seed probes make the recorded differences reproducible across runs. A dedicated
    # generator leaves the global RNG state untouched.
    proj_in_dtype = parallel_model.action_in_proj_arms[0].weight.dtype
    proj_out_dtype = parallel_model.action_out_proj_arms[0].weight.dtype
    probe_rng = torch.Generator().manual_seed(0)
    x = torch.randn(2, model_cfg.action_horizon, model_cfg.action_dim, generator=probe_rng, dtype=proj_in_dtype)
    x_left = x[:, :, 0:16]
    x_right = x[:, :, 16:32]
    suffix = torch.randn(2, model_cfg.action_horizon, hidden_width, generator=probe_rng, dtype=proj_out_dtype)

    metadata = {
        "config_name": args.config_name,
        "action_expert_mode": model_cfg.action_expert_mode,
        "single_ckpt": args.single_ckpt,
        "output_path": args.output_path,
        "load_state_missing_keys": list(missing),
        "load_state_unexpected_keys": list(unexpected),
    }

    with torch.no_grad():
        # Per-arm checks. The arm input heads were given a zero bias by _copy_factorized_heads,
        # so they are compared against the bias-free slice of the single input projection.
        left_input_projection_max_abs_diff = _max_abs_diff(
            F.linear(x_left, weight_in[:, 0:16].to(proj_in_dtype), None),
            parallel_model.action_in_proj_arms[0](x_left),
        )
        right_input_projection_max_abs_diff = _max_abs_diff(
            F.linear(x_right, weight_in[:, 16:32].to(proj_in_dtype), None),
            parallel_model.action_in_proj_arms[1](x_right),
        )
        left_output_projection_max_abs_diff = _max_abs_diff(
            F.linear(suffix, weight_out[0:16, :].to(proj_out_dtype), bias_out[0:16].to(proj_out_dtype)),
            parallel_model.action_out_proj_arms[0](suffix),
        )
        right_output_projection_max_abs_diff = _max_abs_diff(
            F.linear(suffix, weight_out[16:32, :].to(proj_out_dtype), bias_out[16:32].to(proj_out_dtype)),
            parallel_model.action_out_proj_arms[1](suffix),
        )
        metadata.update(
            {
                "left_input_projection_max_abs_diff": left_input_projection_max_abs_diff,
                "right_input_projection_max_abs_diff": right_input_projection_max_abs_diff,
                "left_output_projection_max_abs_diff": left_output_projection_max_abs_diff,
                "right_output_projection_max_abs_diff": right_output_projection_max_abs_diff,
            }
        )

        if model_cfg.action_expert_mode == "head_only_parallel":
            # End-to-end check of the combined projections against the original single head.
            input_max_abs_diff = _max_abs_diff(
                F.linear(x, weight_in.to(proj_in_dtype), bias_in.to(proj_in_dtype)),
                parallel_model._project_action_inputs(x),
            )
            output_max_abs_diff = _max_abs_diff(
                F.linear(suffix, weight_out.to(proj_out_dtype), bias_out.to(proj_out_dtype)),
                parallel_model._project_action_outputs(suffix),
            )
            metadata["input_projection_max_abs_diff"] = input_max_abs_diff
            metadata["output_projection_max_abs_diff"] = output_max_abs_diff
            # NOTE(review): exact float equality. If _project_action_inputs sums the arm
            # projections in a different order than F.linear, rounding can make this False even
            # for a correct copy. Confirm whether a tolerance is intended.
            metadata["warm_start_exact"] = input_max_abs_diff == 0.0 and output_max_abs_diff == 0.0
        else:
            left_expert_max_abs_diff = _expert_copy_max_abs_diff(
                parallel_model,
                single_state,
                "paligemma_with_expert.left_gemma_expert.",
            )
            right_expert_max_abs_diff = _expert_copy_max_abs_diff(
                parallel_model,
                single_state,
                "paligemma_with_expert.right_gemma_expert.",
            )
            metadata["left_expert_max_abs_diff"] = left_expert_max_abs_diff
            metadata["right_expert_max_abs_diff"] = right_expert_max_abs_diff
            if parallel_model.paligemma_with_expert.cross_arm_comm is not None:
                metadata["cross_arm_comm_init"] = [
                    float(value) for value in parallel_model.paligemma_with_expert.cross_arm_comm.detach().cpu().tolist()
                ]
            metadata["warm_start_exact"] = (
                left_input_projection_max_abs_diff == 0.0
                and right_input_projection_max_abs_diff == 0.0
                and left_output_projection_max_abs_diff == 0.0
                and right_output_projection_max_abs_diff == 0.0
                and left_expert_max_abs_diff == 0.0
                and right_expert_max_abs_diff == 0.0
            )

    output_dir = Path(args.output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    safetensors.torch.save_model(parallel_model, output_dir / "model.safetensors")
    (output_dir / "config.json").write_text(json.dumps(dataclasses.asdict(model_cfg), indent=2, sort_keys=True))
    (output_dir / "init_parallel_metadata.json").write_text(json.dumps(metadata, indent=2, sort_keys=True))

    print(f"config_name: {args.config_name}")
    print(f"action_expert_mode: {model_cfg.action_expert_mode}")
    print(f"single_ckpt: {args.single_ckpt}")
    print(f"output_path: {args.output_path}")
    print(f"load_state_missing_keys_count: {len(missing)}")
    print(f"load_state_missing_keys: {list(missing)}")
    print(f"load_state_unexpected_keys_count: {len(unexpected)}")
    print(f"load_state_unexpected_keys: {list(unexpected)}")
    # Keys already printed explicitly above are skipped in the generic dump.
    already_printed = {
        "config_name",
        "action_expert_mode",
        "single_ckpt",
        "output_path",
        "load_state_missing_keys",
        "load_state_unexpected_keys",
    }
    for key in sorted(metadata):
        if key in already_printed:
            continue
        print(f"{key}: {metadata[key]}")
|
|
|
|
# Script entry point: run the conversion only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|