#!/usr/bin/env python3 """ Load a JAX model and print all parameter keys, with optional conversion to PyTorch. This script loads a JAX model checkpoint using orbax and can either: 1. Print out all the parameter keys in a hierarchical structure for inspection 2. Convert the JAX model to PyTorch format using our PI0Pytorch model Usage: # Just inspect keys: python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only # Convert to PyTorch: python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output Example: # pi0_droid python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch # pi0_aloha_sim python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch # pi05_droid python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch """ import json import os import pathlib import shutil from typing import Literal from flax.nnx import traversals import numpy as np import orbax.checkpoint as ocp import safetensors import torch import tyro import openpi.models.gemma import openpi.models.model import openpi.models.pi0_config import openpi.models_pytorch.pi0_pytorch from openpi.training import utils import openpi.training.config as _config def slice_paligemma_state_dict(state_dict, config): """Convert PaliGemma JAX parameters to PyTorch format.""" suffix = "/value" if "img/embedding/kernel/value" in state_dict else "" # patch embeddings jax_key = f"img/embedding/kernel{suffix}" pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight" state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1) jax_key = f"img/embedding/bias{suffix}" pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias" state_dict[pytorch_key] = state_dict.pop(jax_key) # positional embeddings jax_key = f"img/pos_embedding{suffix}" pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight" state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size) # extract vision layers to be sliced at index 0. There are 27 layers in the base model. encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}") encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}") encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}") encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}") encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}") encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}") encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}") encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}") encoderblock_attention_0_key_kernel = state_dict.pop( f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}" ) encoderblock_attention_0_key_bias = state_dict.pop( f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}" ) encoderblock_attention_0_value_kernel = state_dict.pop( f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}" ) encoderblock_attention_0_value_bias = state_dict.pop( f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}" ) encoderblock_attention_0_query_kernel = state_dict.pop( f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}" ) encoderblock_attention_0_query_bias = state_dict.pop( f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}" ) encoderblock_attention_0_out_kernel = state_dict.pop( f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}" ) encoderblock_attention_0_out_bias = state_dict.pop( f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}" ) for i in range(config.vision_config.num_hidden_layers): state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight" ] = encoderblock_layernorm0_scale[i].transpose() state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias" ] = encoderblock_layernorm0_bias[i] state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight" ] = encoderblock_layernorm1_scale[i].transpose() state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias" ] = encoderblock_layernorm1_bias[i] state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight" ] = encoderblock_mlp_dense0_kernel[i].transpose() state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias" ] = encoderblock_mlp_dense0_bias[i] state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight" ] = encoderblock_mlp_dense1_kernel[i].transpose() state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias" ] = encoderblock_mlp_dense1_bias[i] state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight" ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias" ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight" ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias" ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight" ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias" ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight" ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() state_dict[ f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias" ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) jax_key = f"img/Transformer/encoder_norm/scale{suffix}" pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight" state_dict[pytorch_key] = state_dict.pop(jax_key).transpose() jax_key = f"img/Transformer/encoder_norm/bias{suffix}" pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias" state_dict[pytorch_key] = state_dict.pop(jax_key) # multimodal projector jax_key = f"img/head/kernel{suffix}" pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight" state_dict[pytorch_key] = state_dict.pop(jax_key).transpose() jax_key = f"img/head/bias{suffix}" pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias" state_dict[pytorch_key] = state_dict.pop(jax_key) # text decoder (gemma) jax_key = f"llm/embedder/input_embedding{suffix}" pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" state_dict[pytorch_key] = state_dict.pop(jax_key) # pop the einsum attention + mlp representations llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}") llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}") llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}") llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}") llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}") llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}") llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}") for i in range(config.text_config.num_hidden_layers): q_proj_weight_reshaped = ( llm_attention_q_einsum[i] .transpose(0, 2, 1) .reshape( config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size ) ) state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = ( q_proj_weight_reshaped ) k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = ( k_proj_weight_reshaped ) v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = ( v_proj_weight_reshaped ) o_proj_weight_reshaped = ( llm_attention_attn_vec_einsum[i] .transpose(2, 0, 1) .reshape( config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size ) ) state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = ( o_proj_weight_reshaped ) gate_proj_weight = llm_mlp_gating_einsum[i, 0] state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = ( gate_proj_weight.transpose() ) up_proj_weight = llm_mlp_gating_einsum[i, 1] state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = ( up_proj_weight.transpose() ) state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = ( llm_mlp_linear[i].transpose() ) state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = ( llm_input_layernorm[i] ) state_dict[ f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight" ] = llm_post_attention_layernorm[i] jax_key = f"llm/final_norm/scale{suffix}" pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight" state_dict[pytorch_key] = state_dict.pop(jax_key) expert_dict = {} final_state_dict = {} # Expert-related keys to extract (including pi05 Dense layer parameters) expert_keys = [ f"llm/final_norm_1/scale{suffix}", f"llm/final_norm_1/Dense_0/bias{suffix}", f"llm/final_norm_1/Dense_0/kernel{suffix}", f"llm/layers/attn/attn_vec_einsum_1/w{suffix}", f"llm/layers/attn/kv_einsum_1/w{suffix}", f"llm/layers/attn/q_einsum_1/w{suffix}", f"llm/layers/mlp_1/gating_einsum{suffix}", f"llm/layers/mlp_1/linear{suffix}", f"llm/layers/pre_attention_norm_1/scale{suffix}", f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}", f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}", f"llm/layers/pre_ffw_norm_1/scale{suffix}", f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}", f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}", ] for key, value in state_dict.items(): if key not in expert_keys: final_state_dict[key] = torch.from_numpy(value) else: expert_dict[key] = value return final_state_dict, expert_dict def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05): """Convert Gemma JAX parameters to PyTorch format.""" # Add missing attributes to config if they don't exist if not hasattr(config, "vocab_size"): config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE if not hasattr(config, "hidden_size"): config.hidden_size = config.width if not hasattr(config, "num_hidden_layers"): config.num_hidden_layers = config.depth if not hasattr(config, "num_attention_heads"): config.num_attention_heads = config.num_heads suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else "" llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}") llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}") llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}") llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}") llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}") # Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0) if "pi05" in checkpoint_dir: # Pi05 with adaptive normalization llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}") llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}") llm_input_layernorm_kernel = state_dict.pop( f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}" ) llm_post_attention_layernorm_kernel = state_dict.pop( f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}" ) else: # Regular pi0 with standard RMSNorm llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}") llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}") for i in range(config.num_hidden_layers): q_proj_weight_reshaped = ( llm_attention_q_einsum[i] .transpose(0, 2, 1) .reshape(config.num_attention_heads * config.head_dim, config.hidden_size) ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = ( q_proj_weight_reshaped ) k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = ( k_proj_weight_reshaped ) v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = ( v_proj_weight_reshaped ) o_proj_weight_reshaped = ( llm_attention_attn_vec_einsum[i] .reshape(config.num_attention_heads * config.head_dim, config.hidden_size) .transpose(1, 0) ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = ( o_proj_weight_reshaped ) gate_proj_weight = llm_mlp_gating_einsum[i, 0] state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = ( gate_proj_weight.transpose() ) up_proj_weight = llm_mlp_gating_einsum[i, 1] state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = ( up_proj_weight.transpose() ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[ i ].transpose() if "pi05" in checkpoint_dir: # Pi05 with adaptive normalization - use Dense layer parameters directly state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = ( llm_input_layernorm_bias[i] ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = ( llm_post_attention_layernorm_bias[i] ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = ( llm_input_layernorm_kernel[i].transpose() ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = ( llm_post_attention_layernorm_kernel[i].transpose() ) else: # Regular pi0 with standard RMSNorm state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = ( llm_input_layernorm[i] ) state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = ( llm_post_attention_layernorm[i] ) # Handle final norm layer if "pi05" in checkpoint_dir: # Pi05 with adaptive normalization - use Dense layer parameters directly final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}") final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}") state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose() else: # Regular pi0 with standard RMSNorm state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop( f"llm/final_norm_{num_expert}/scale{suffix}" ) # state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. final_state_dict = {} for key, value in state_dict.items(): if not isinstance(value, torch.Tensor): final_state_dict[key] = torch.from_numpy(value) else: final_state_dict[key] = value return final_state_dict def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None): """Load and process params by restoring via JAX model loader first. This respects dtype conversions that occur during model restore. """ # Use repository restore utility to load a pure dict of params (value suffix removed) params = openpi.models.model.restore_params( f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision ) return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params} def load_jax_model_and_print_keys(checkpoint_dir: str): """ Load JAX model from checkpoint and print all parameter keys. Args: checkpoint_dir: Path to the checkpoint directory """ checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir # Initialize checkpointer checkpointer = ocp.PyTreeCheckpointer() metadata = checkpointer.metadata(f"{checkpoint_dir}/params") print(utils.array_tree_to_info(metadata)) def convert_pi0_checkpoint( checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config ): """ Convert PI0 JAX checkpoint to PyTorch format. Args: checkpoint_dir: Path to the JAX checkpoint precision: Model precision (float32, bfloat16, float16) output_path: Path to save the converted PyTorch model model_config: Model config """ print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}") print(f"Model config: {model_config}") # Break down orbax ckpts by restoring via JAX to respect dtype initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32") # Process projection params if model_config.pi05: keys = [ "action_in_proj", "action_out_proj", "time_mlp_in", "time_mlp_out", ] else: keys = [ "state_proj", "action_in_proj", "action_out_proj", "action_time_mlp_in", "action_time_mlp_out", ] projection_params = {} for key in keys: kernel_params = initial_params["projection_params"][key]["kernel"] bias_params = initial_params["projection_params"][key]["bias"] if isinstance(kernel_params, dict): weight = kernel_params["value"] bias = bias_params["value"] else: weight = kernel_params bias = bias_params pytorch_weight_key = f"{key}.weight" pytorch_bias_key = f"{key}.bias" projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias)) # Create configs based on checkpoint path # All models use the same PaliGemma config structure class PaliGemmaConfig: def __init__(self): self.vision_config = type( "obj", (object,), { "hidden_size": 1152, "num_hidden_layers": 27, "num_attention_heads": 16, "intermediate_size": 4304, "patch_size": 14, "projection_dim": 2048, }, )() self.text_config = type( "obj", (object,), { "hidden_size": 2048, "num_hidden_layers": 18, "num_attention_heads": 8, "head_dim": 256, "intermediate_size": 16384, }, )() paligemma_config = PaliGemmaConfig() action_expert_config = openpi.models.gemma.get_config("gemma_300m") # Process PaliGemma weights paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config) # Process Gemma weights from expert_params gemma_params = slice_gemma_state_dict( expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05 ) # Instantiate model pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config) # Combine all parameters (no prefix needed for our model structure) all_params = {**paligemma_params, **gemma_params, **projection_params} # Load state dict pi0_model.load_state_dict(all_params, strict=False) if precision == "float32": pi0_model = pi0_model.to(torch.float32) elif precision == "bfloat16": pi0_model = pi0_model.to(torch.bfloat16) else: raise ValueError(f"Invalid precision: {precision}") # Save the converted model using safetensors os.makedirs(output_path, exist_ok=True) # Save model weights as SafeTensors using save_model to handle tied weights safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors")) # Copy assets folder if it exists assets_source = pathlib.Path(checkpoint_dir).parent / "assets" if assets_source.exists(): assets_dest = pathlib.Path(output_path) / "assets" if assets_dest.exists(): shutil.rmtree(assets_dest) shutil.copytree(assets_source, assets_dest) # Save config as JSON for reference config_dict = { "action_dim": model_config.action_dim, "action_horizon": model_config.action_horizon, "paligemma_variant": model_config.paligemma_variant, "action_expert_variant": model_config.action_expert_variant, "precision": precision, } with open(os.path.join(output_path, "config.json"), "w") as f: json.dump(config_dict, f, indent=2) print("Model conversion completed successfully!") print(f"Model saved to {output_path}") def main( checkpoint_dir: str, config_name: str, output_path: str | None = None, precision: Literal["float32", "bfloat16", "float16"] = "bfloat16", *, inspect_only: bool = False, ): """Load JAX model and optionally convert to PyTorch. Args: checkpoint_dir: Path to the JAX checkpoint directory output_path: Path to save converted PyTorch model (required for conversion) precision: Precision for model conversion inspect_only: Only inspect parameter keys, don't convert """ model_config = _config.get_config(config_name).model if not isinstance(model_config, openpi.models.pi0_config.Pi0Config): raise ValueError(f"Config {config_name} is not a Pi0Config") if inspect_only: load_jax_model_and_print_keys(checkpoint_dir) else: if not output_path: print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.") return convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config) if __name__ == "__main__": tyro.cli(main)