openpi / examples /convert_jax_model_to_pytorch.py
Kevin Black
update paths
2d70d96
#!/usr/bin/env python3
"""
Load a JAX model and print all parameter keys, with optional conversion to PyTorch.
This script loads a JAX model checkpoint using orbax and can either:
1. Print out all the parameter keys in a hierarchical structure for inspection
2. Convert the JAX model to PyTorch format using our PI0Pytorch model
Usage:
# Just inspect keys:
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --inspect_only
# Convert to PyTorch:
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --output_path /path/to/output
Example:
# pi0_droid
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch
# pi0_aloha_sim
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
# pi05_droid
python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid/params --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
"""
import json
import os
import pathlib
import shutil
from typing import Literal
from flax.nnx import traversals
import numpy as np
import orbax.checkpoint as ocp
import safetensors
import torch
import tyro
import openpi.models.gemma
import openpi.models.model
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
from openpi.training import utils
import openpi.training.config as _config
def slice_paligemma_state_dict(state_dict, config):
"""Convert PaliGemma JAX parameters to PyTorch format."""
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
# patch embeddings
jax_key = f"img/embedding/kernel{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)
jax_key = f"img/embedding/bias{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# positional embeddings
jax_key = f"img/pos_embedding{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
encoderblock_attention_0_key_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
)
encoderblock_attention_0_key_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
)
encoderblock_attention_0_value_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
)
encoderblock_attention_0_value_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
)
encoderblock_attention_0_query_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
)
encoderblock_attention_0_query_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
)
encoderblock_attention_0_out_kernel = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
)
encoderblock_attention_0_out_bias = state_dict.pop(
f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
)
for i in range(config.vision_config.num_hidden_layers):
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
] = encoderblock_layernorm0_scale[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
] = encoderblock_layernorm0_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
] = encoderblock_layernorm1_scale[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
] = encoderblock_layernorm1_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
] = encoderblock_mlp_dense0_kernel[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
] = encoderblock_mlp_dense0_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
] = encoderblock_mlp_dense1_kernel[i].transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
] = encoderblock_mlp_dense1_bias[i]
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[
f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# multimodal projector
jax_key = f"img/head/kernel{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()
jax_key = f"img/head/bias{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# text decoder (gemma)
jax_key = f"llm/embedder/input_embedding{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key)
# pop the einsum attention + mlp representations
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
for i in range(config.text_config.num_hidden_layers):
q_proj_weight_reshaped = (
llm_attention_q_einsum[i]
.transpose(0, 2, 1)
.reshape(
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
)
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
q_proj_weight_reshaped
)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
k_proj_weight_reshaped
)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
v_proj_weight_reshaped
)
o_proj_weight_reshaped = (
llm_attention_attn_vec_einsum[i]
.transpose(2, 0, 1)
.reshape(
config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
)
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
o_proj_weight_reshaped
)
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
gate_proj_weight.transpose()
)
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
up_proj_weight.transpose()
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
llm_mlp_linear[i].transpose()
)
state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
llm_input_layernorm[i]
)
state_dict[
f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
] = llm_post_attention_layernorm[i]
jax_key = f"llm/final_norm/scale{suffix}"
pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
state_dict[pytorch_key] = state_dict.pop(jax_key)
expert_dict = {}
final_state_dict = {}
# Expert-related keys to extract (including pi05 Dense layer parameters)
expert_keys = [
f"llm/final_norm_1/scale{suffix}",
f"llm/final_norm_1/Dense_0/bias{suffix}",
f"llm/final_norm_1/Dense_0/kernel{suffix}",
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
f"llm/layers/attn/kv_einsum_1/w{suffix}",
f"llm/layers/attn/q_einsum_1/w{suffix}",
f"llm/layers/mlp_1/gating_einsum{suffix}",
f"llm/layers/mlp_1/linear{suffix}",
f"llm/layers/pre_attention_norm_1/scale{suffix}",
f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
]
for key, value in state_dict.items():
if key not in expert_keys:
final_state_dict[key] = torch.from_numpy(value)
else:
expert_dict[key] = value
return final_state_dict, expert_dict
def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
"""Convert Gemma JAX parameters to PyTorch format."""
# Add missing attributes to config if they don't exist
if not hasattr(config, "vocab_size"):
config.vocab_size = 257152 # PALIGEMMA_VOCAB_SIZE
if not hasattr(config, "hidden_size"):
config.hidden_size = config.width
if not hasattr(config, "num_hidden_layers"):
config.num_hidden_layers = config.depth
if not hasattr(config, "num_attention_heads"):
config.num_attention_heads = config.num_heads
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
# Check if we have Dense layers (for pi05/adaptive normalization) or scale layers (for regular pi0)
if "pi05" in checkpoint_dir:
# Pi05 with adaptive normalization
llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
llm_input_layernorm_kernel = state_dict.pop(
f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
)
llm_post_attention_layernorm_kernel = state_dict.pop(
f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
)
else:
# Regular pi0 with standard RMSNorm
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
for i in range(config.num_hidden_layers):
q_proj_weight_reshaped = (
llm_attention_q_einsum[i]
.transpose(0, 2, 1)
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
q_proj_weight_reshaped
)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
k_proj_weight_reshaped
)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
v_proj_weight_reshaped
)
o_proj_weight_reshaped = (
llm_attention_attn_vec_einsum[i]
.reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
.transpose(1, 0)
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
o_proj_weight_reshaped
)
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
gate_proj_weight.transpose()
)
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
up_proj_weight.transpose()
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[
i
].transpose()
if "pi05" in checkpoint_dir:
# Pi05 with adaptive normalization - use Dense layer parameters directly
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
llm_input_layernorm_bias[i]
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
llm_post_attention_layernorm_bias[i]
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
llm_input_layernorm_kernel[i].transpose()
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
llm_post_attention_layernorm_kernel[i].transpose()
)
else:
# Regular pi0 with standard RMSNorm
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
llm_input_layernorm[i]
)
state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
llm_post_attention_layernorm[i]
)
# Handle final norm layer
if "pi05" in checkpoint_dir:
# Pi05 with adaptive normalization - use Dense layer parameters directly
final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
else:
# Regular pi0 with standard RMSNorm
state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
f"llm/final_norm_{num_expert}/scale{suffix}"
)
# state_dict["paligemma_with_expert.gemma_expert.lm_head.weight"] = embedding_vector # weights are tied.
final_state_dict = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
final_state_dict[key] = torch.from_numpy(value)
else:
final_state_dict[key] = value
return final_state_dict
def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
"""Load and process params by restoring via JAX model loader first.
This respects dtype conversions that occur during model restore.
"""
# Use repository restore utility to load a pure dict of params (value suffix removed)
params = openpi.models.model.restore_params(
f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
)
return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}
def load_jax_model_and_print_keys(checkpoint_dir: str):
"""
Load JAX model from checkpoint and print all parameter keys.
Args:
checkpoint_dir: Path to the checkpoint directory
"""
checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir
# Initialize checkpointer
checkpointer = ocp.PyTreeCheckpointer()
metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
print(utils.array_tree_to_info(metadata))
def convert_pi0_checkpoint(
checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
):
"""
Convert PI0 JAX checkpoint to PyTorch format.
Args:
checkpoint_dir: Path to the JAX checkpoint
precision: Model precision (float32, bfloat16, float16)
output_path: Path to save the converted PyTorch model
model_config: Model config
"""
print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
print(f"Model config: {model_config}")
# Break down orbax ckpts by restoring via JAX to respect dtype
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")
# Process projection params
if model_config.pi05:
keys = [
"action_in_proj",
"action_out_proj",
"time_mlp_in",
"time_mlp_out",
]
else:
keys = [
"state_proj",
"action_in_proj",
"action_out_proj",
"action_time_mlp_in",
"action_time_mlp_out",
]
projection_params = {}
for key in keys:
kernel_params = initial_params["projection_params"][key]["kernel"]
bias_params = initial_params["projection_params"][key]["bias"]
if isinstance(kernel_params, dict):
weight = kernel_params["value"]
bias = bias_params["value"]
else:
weight = kernel_params
bias = bias_params
pytorch_weight_key = f"{key}.weight"
pytorch_bias_key = f"{key}.bias"
projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
# Create configs based on checkpoint path
# All models use the same PaliGemma config structure
class PaliGemmaConfig:
def __init__(self):
self.vision_config = type(
"obj",
(object,),
{
"hidden_size": 1152,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"intermediate_size": 4304,
"patch_size": 14,
"projection_dim": 2048,
},
)()
self.text_config = type(
"obj",
(object,),
{
"hidden_size": 2048,
"num_hidden_layers": 18,
"num_attention_heads": 8,
"head_dim": 256,
"intermediate_size": 16384,
},
)()
paligemma_config = PaliGemmaConfig()
action_expert_config = openpi.models.gemma.get_config("gemma_300m")
# Process PaliGemma weights
paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
# Process Gemma weights from expert_params
gemma_params = slice_gemma_state_dict(
expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
)
# Instantiate model
pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
# Combine all parameters (no prefix needed for our model structure)
all_params = {**paligemma_params, **gemma_params, **projection_params}
# Load state dict
pi0_model.load_state_dict(all_params, strict=False)
if precision == "float32":
pi0_model = pi0_model.to(torch.float32)
elif precision == "bfloat16":
pi0_model = pi0_model.to(torch.bfloat16)
else:
raise ValueError(f"Invalid precision: {precision}")
# Save the converted model using safetensors
os.makedirs(output_path, exist_ok=True)
# Save model weights as SafeTensors using save_model to handle tied weights
safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
# Copy assets folder if it exists
assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
if assets_source.exists():
assets_dest = pathlib.Path(output_path) / "assets"
if assets_dest.exists():
shutil.rmtree(assets_dest)
shutil.copytree(assets_source, assets_dest)
# Save config as JSON for reference
config_dict = {
"action_dim": model_config.action_dim,
"action_horizon": model_config.action_horizon,
"paligemma_variant": model_config.paligemma_variant,
"action_expert_variant": model_config.action_expert_variant,
"precision": precision,
}
with open(os.path.join(output_path, "config.json"), "w") as f:
json.dump(config_dict, f, indent=2)
print("Model conversion completed successfully!")
print(f"Model saved to {output_path}")
def main(
checkpoint_dir: str,
config_name: str,
output_path: str | None = None,
precision: Literal["float32", "bfloat16", "float16"] = "bfloat16",
*,
inspect_only: bool = False,
):
"""Load JAX model and optionally convert to PyTorch.
Args:
checkpoint_dir: Path to the JAX checkpoint directory
output_path: Path to save converted PyTorch model (required for conversion)
precision: Precision for model conversion
inspect_only: Only inspect parameter keys, don't convert
"""
model_config = _config.get_config(config_name).model
if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
raise ValueError(f"Config {config_name} is not a Pi0Config")
if inspect_only:
load_jax_model_and_print_keys(checkpoint_dir)
else:
if not output_path:
print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
return
convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)
if __name__ == "__main__":
tyro.cli(main)