"""Load a JAX model and print all parameter keys, with optional conversion to PyTorch.

This script loads a JAX model checkpoint using orbax and can either:
1. Print out all the parameter keys in a hierarchical structure for inspection
2. Convert the JAX model to PyTorch format using our PI0Pytorch model

Usage:
    # Just inspect keys:
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --config_name <config_name> --inspect_only

    # Convert to PyTorch:
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /path/to/checkpoint --config_name <config_name> --output_path /path/to/output

Example:
    # pi0_droid
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid/params --config_name pi0_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_droid_pytorch

    # pi0_aloha_sim
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params --config_name pi0_aloha_sim --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch

    # pi05_droid
    python examples/convert_jax_model_to_pytorch.py --checkpoint_dir /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid/params --config_name pi05_droid --output_path /home/$USER/.cache/openpi/openpi-assets/checkpoints/pi05_droid_pytorch
"""

import json
import os
import pathlib
import shutil
from typing import Literal

from flax.nnx import traversals
import numpy as np
import orbax.checkpoint as ocp
import safetensors.torch
import torch
import tyro

import openpi.models.gemma
import openpi.models.model
import openpi.models.pi0_config
import openpi.models_pytorch.pi0_pytorch
from openpi.training import utils
import openpi.training.config as _config
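
# Note: the conversion functions below map flattened orbax/JAX parameter paths
# (e.g. "img/embedding/kernel" or "llm/layers/attn/q_einsum/w") onto the parameter
# names of the PyTorch PI0Pytorch model ("paligemma_with_expert....").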


def slice_paligemma_state_dict(state_dict, config):
    """Convert PaliGemma JAX parameters to PyTorch format."""
    # Some checkpoints store every leaf under a trailing "/value" key; detect that once.
    suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""

    # Patch embedding: JAX conv kernels are laid out (H, W, C_in, C_out); PyTorch expects (C_out, C_in, H, W).
    jax_key = f"img/embedding/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose(3, 2, 0, 1)

    jax_key = f"img/embedding/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.patch_embedding.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # Positional embedding: reshape to (num_positions, hidden_size).
    jax_key = f"img/pos_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.embeddings.position_embedding.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).reshape(-1, config.vision_config.hidden_size)
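
    # Layout note (shapes illustrative for the SigLIP So400m/14 tower used here): every
    # encoder-block parameter is stored once, stacked over layers on a leading axis, e.g. the
    # attention key kernel is (num_layers, hidden_size, num_heads, head_dim). The loop below
    # indexes layer i and reshapes/transposes each slice into the (out_features, in_features)
    # layout that torch.nn.Linear expects.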
    encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
    encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
    encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
    encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")

    encoderblock_mlp_dense0_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
    encoderblock_mlp_dense0_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
    encoderblock_mlp_dense1_kernel = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
    encoderblock_mlp_dense1_bias = state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")

    encoderblock_attention_0_key_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}"
    )
    encoderblock_attention_0_key_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}"
    )
    encoderblock_attention_0_value_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}"
    )
    encoderblock_attention_0_value_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}"
    )
    encoderblock_attention_0_query_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}"
    )
    encoderblock_attention_0_query_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}"
    )
    encoderblock_attention_0_out_kernel = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}"
    )
    encoderblock_attention_0_out_bias = state_dict.pop(
        f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}"
    )

    for i in range(config.vision_config.num_hidden_layers):
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"
        ] = encoderblock_layernorm0_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"
        ] = encoderblock_layernorm0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"
        ] = encoderblock_layernorm1_scale[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"
        ] = encoderblock_layernorm1_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"
        ] = encoderblock_mlp_dense0_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"
        ] = encoderblock_mlp_dense0_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"
        ] = encoderblock_mlp_dense1_kernel[i].transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"
        ] = encoderblock_mlp_dense1_bias[i]
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"
        ] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"
        ] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"
        ] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"
        ] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"
        ] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"
        ] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"
        ] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
        state_dict[
            f"paligemma_with_expert.paligemma.model.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"
        ] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)

    # Final vision encoder layer norm.
    jax_key = f"img/Transformer/encoder_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/Transformer/encoder_norm/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.vision_tower.vision_model.post_layernorm.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # Multimodal projector ("img/head" in the JAX checkpoint).
    jax_key = f"img/head/kernel{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key).transpose()

    jax_key = f"img/head/bias{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.multi_modal_projector.linear.bias"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    # Token embeddings of the Gemma language model.
    jax_key = f"llm/embedder/input_embedding{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)
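
    # Layout note (axis order inferred from the reshapes below; treat as a sketch):
    #   q_einsum/w        has per-layer shape (num_heads, hidden_size, head_dim)
    #   kv_einsum/w       has per-layer shape (2, num_kv_heads, hidden_size, head_dim), index 0 = K, 1 = V
    #   attn_vec_einsum/w has per-layer shape (num_heads, head_dim, hidden_size)
    # Each slice is reshaped/transposed into the (out_features, in_features) matrix expected by torch.nn.Linear.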
    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")

    llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
    llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")

    for i in range(config.text_config.num_hidden_layers):
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .transpose(2, 0, 1)
            .reshape(
                config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size
            )
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.mlp.down_proj.weight"] = (
            llm_mlp_linear[i].transpose()
        )
        state_dict[f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.input_layernorm.weight"] = (
            llm_input_layernorm[i]
        )
        state_dict[
            f"paligemma_with_expert.paligemma.model.language_model.layers.{i}.post_attention_layernorm.weight"
        ] = llm_post_attention_layernorm[i]

    # Final language-model norm.
    jax_key = f"llm/final_norm/scale{suffix}"
    pytorch_key = "paligemma_with_expert.paligemma.model.language_model.norm.weight"
    state_dict[pytorch_key] = state_dict.pop(jax_key)

    expert_dict = {}
    final_state_dict = {}

    # Keys with a "_1" suffix belong to the action expert; they are returned untouched and
    # converted separately by slice_gemma_state_dict. Everything else is converted to torch here.
    expert_keys = [
        f"llm/final_norm_1/scale{suffix}",
        f"llm/final_norm_1/Dense_0/bias{suffix}",
        f"llm/final_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
        f"llm/layers/attn/kv_einsum_1/w{suffix}",
        f"llm/layers/attn/q_einsum_1/w{suffix}",
        f"llm/layers/mlp_1/gating_einsum{suffix}",
        f"llm/layers/mlp_1/linear{suffix}",
        f"llm/layers/pre_attention_norm_1/scale{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_attention_norm_1/Dense_0/kernel{suffix}",
        f"llm/layers/pre_ffw_norm_1/scale{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/bias{suffix}",
        f"llm/layers/pre_ffw_norm_1/Dense_0/kernel{suffix}",
    ]

    for key, value in state_dict.items():
        if key not in expert_keys:
            final_state_dict[key] = torch.from_numpy(value)
        else:
            expert_dict[key] = value

    return final_state_dict, expert_dict


def slice_gemma_state_dict(state_dict, config, *, num_expert, checkpoint_dir, pi05):
    """Convert Gemma JAX parameters to PyTorch format."""
    # openpi gemma configs expose width/depth/num_heads; add the HF-style aliases used below
    # (257152 is the PaliGemma vocabulary size).
    if not hasattr(config, "vocab_size"):
        config.vocab_size = 257152
    if not hasattr(config, "hidden_size"):
        config.hidden_size = config.width
    if not hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = config.depth
    if not hasattr(config, "num_attention_heads"):
        config.num_attention_heads = config.num_heads

    suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""

    llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
    llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
    llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")

    llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
    llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
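
    # pi05 checkpoints store the expert's norm parameters as a small Dense layer
    # (Dense_0 kernel/bias, i.e. an adaptive RMSNorm) rather than a single scale vector.
    # This reading is inferred from the key names; the conversion below only relies on it structurally.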
    if pi05:
        llm_input_layernorm_bias = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_post_attention_layernorm_bias = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/bias{suffix}")
        llm_input_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_attention_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
        llm_post_attention_layernorm_kernel = state_dict.pop(
            f"llm/layers/pre_ffw_norm_{num_expert}/Dense_0/kernel{suffix}"
        )
    else:
        llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
        llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")

    for i in range(config.num_hidden_layers):
        q_proj_weight_reshaped = (
            llm_attention_q_einsum[i]
            .transpose(0, 2, 1)
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = (
            q_proj_weight_reshaped
        )

        k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = (
            k_proj_weight_reshaped
        )
        v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = (
            v_proj_weight_reshaped
        )

        o_proj_weight_reshaped = (
            llm_attention_attn_vec_einsum[i]
            .reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
            .transpose(1, 0)
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = (
            o_proj_weight_reshaped
        )

        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = (
            gate_proj_weight.transpose()
        )
        up_proj_weight = llm_mlp_gating_einsum[i, 1]
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = (
            up_proj_weight.transpose()
        )
        state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = (
            llm_mlp_linear[i].transpose()
        )

        if pi05:
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.bias"] = (
                llm_input_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.bias"] = (
                llm_post_attention_layernorm_bias[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.dense.weight"] = (
                llm_input_layernorm_kernel[i].transpose()
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.dense.weight"] = (
                llm_post_attention_layernorm_kernel[i].transpose()
            )
        else:
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.input_layernorm.weight"] = (
                llm_input_layernorm[i]
            )
            state_dict[f"paligemma_with_expert.gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = (
                llm_post_attention_layernorm[i]
            )

    # Final norm of the expert.
    if pi05:
        final_norm_bias = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/bias{suffix}")
        final_norm_kernel = state_dict.pop(f"llm/final_norm_{num_expert}/Dense_0/kernel{suffix}")
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.bias"] = final_norm_bias
        state_dict["paligemma_with_expert.gemma_expert.model.norm.dense.weight"] = final_norm_kernel.transpose()
    else:
        state_dict["paligemma_with_expert.gemma_expert.model.norm.weight"] = state_dict.pop(
            f"llm/final_norm_{num_expert}/scale{suffix}"
        )

    # Convert any remaining numpy arrays to torch tensors.
    final_state_dict = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor):
            final_state_dict[key] = torch.from_numpy(value)
        else:
            final_state_dict[key] = value

    return final_state_dict


def slice_initial_orbax_checkpoint(checkpoint_dir: str, restore_precision: str | None = None):
    """Load and flatten params by restoring them through the JAX model loader first.

    Going through openpi.models.model.restore_params respects any dtype conversion
    that happens during model restore.
    """
    params = openpi.models.model.restore_params(
        f"{checkpoint_dir}/params/", restore_type=np.ndarray, dtype=restore_precision
    )

    # Flatten the nested PaliGemma tree into "a/b/c"-style keys; the projection params keep their nested form.
    return {"paligemma_params": traversals.flatten_mapping(params["PaliGemma"], sep="/"), "projection_params": params}


def load_jax_model_and_print_keys(checkpoint_dir: str):
    """Load JAX checkpoint metadata and print all parameter keys.

    Args:
        checkpoint_dir: Path to the checkpoint directory (local path or gs:// URI)
    """
    # Keep gs:// URIs untouched; only local paths are made absolute.
    checkpoint_dir = os.path.abspath(checkpoint_dir) if not checkpoint_dir.startswith("gs://") else checkpoint_dir

    # Reading only the checkpoint metadata is enough to list the parameter tree; no weights are loaded.
    checkpointer = ocp.PyTreeCheckpointer()
    metadata = checkpointer.metadata(f"{checkpoint_dir}/params")
    print(utils.array_tree_to_info(metadata))


def convert_pi0_checkpoint(
    checkpoint_dir: str, precision: str, output_path: str, model_config: openpi.models.pi0_config.Pi0Config
):
    """Convert a PI0 JAX checkpoint to PyTorch format.

    Args:
        checkpoint_dir: Path to the JAX checkpoint
        precision: Precision to save the converted model in (float32 or bfloat16)
        output_path: Path to save the converted PyTorch model
        model_config: Model config
    """
    print(f"Converting PI0 checkpoint from {checkpoint_dir} to {output_path}")
    print(f"Model config: {model_config}")

    # Restore in float32 first; the requested precision is applied after the PyTorch model is assembled.
    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir, restore_precision="float32")

    # The set of projection layers differs between pi0 and pi05 (e.g. pi05 has no state_proj
    # and uses time_mlp_* instead of action_time_mlp_*).
    if model_config.pi05:
        keys = [
            "action_in_proj",
            "action_out_proj",
            "time_mlp_in",
            "time_mlp_out",
        ]
    else:
        keys = [
            "state_proj",
            "action_in_proj",
            "action_out_proj",
            "action_time_mlp_in",
            "action_time_mlp_out",
        ]

    projection_params = {}
    for key in keys:
        kernel_params = initial_params["projection_params"][key]["kernel"]
        bias_params = initial_params["projection_params"][key]["bias"]
        if isinstance(kernel_params, dict):
            # Some checkpoints wrap each leaf in a {"value": ...} dict.
            weight = kernel_params["value"]
            bias = bias_params["value"]
        else:
            weight = kernel_params
            bias = bias_params

        pytorch_weight_key = f"{key}.weight"
        pytorch_bias_key = f"{key}.bias"

        # JAX dense kernels are (in_features, out_features); transpose for torch.nn.Linear.
        projection_params[pytorch_weight_key] = torch.from_numpy(np.array(weight)).T
        projection_params[pytorch_bias_key] = torch.from_numpy(np.array(bias))
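
    # The hard-coded dims below match the PaliGemma 3B backbone used by pi0: a SigLIP So400m/14
    # vision tower (hidden 1152, 27 layers) and a Gemma 2B language model (hidden 2048, 18 layers).
    # They only drive the slicing functions above; the PyTorch model itself is built from model_config.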
    class PaliGemmaConfig:
        def __init__(self):
            self.vision_config = type(
                "obj",
                (object,),
                {
                    "hidden_size": 1152,
                    "num_hidden_layers": 27,
                    "num_attention_heads": 16,
                    "intermediate_size": 4304,
                    "patch_size": 14,
                    "projection_dim": 2048,
                },
            )()
            self.text_config = type(
                "obj",
                (object,),
                {
                    "hidden_size": 2048,
                    "num_hidden_layers": 18,
                    "num_attention_heads": 8,
                    "head_dim": 256,
                    "intermediate_size": 16384,
                },
            )()

    paligemma_config = PaliGemmaConfig()
    action_expert_config = openpi.models.gemma.get_config("gemma_300m")

    # Convert the PaliGemma weights; the action-expert ("_1") weights are returned separately and converted next.
    paligemma_params, expert_params = slice_paligemma_state_dict(initial_params["paligemma_params"], paligemma_config)
    gemma_params = slice_gemma_state_dict(
        expert_params, action_expert_config, num_expert=1, checkpoint_dir=checkpoint_dir, pi05=model_config.pi05
    )

    # Build the PyTorch model and load the converted weights.
    pi0_model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
    all_params = {**paligemma_params, **gemma_params, **projection_params}
    pi0_model.load_state_dict(all_params, strict=False)
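
    # Note: strict=False leaves any parameters missing from `all_params` at their initialized values.
    # A sketch for sanity-checking the load (the call above could be written as):
    #   missing, unexpected = pi0_model.load_state_dict(all_params, strict=False)
    #   print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")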

    if precision == "float32":
        pi0_model = pi0_model.to(torch.float32)
    elif precision == "bfloat16":
        pi0_model = pi0_model.to(torch.bfloat16)
    else:
        raise ValueError(f"Invalid precision: {precision}")

    os.makedirs(output_path, exist_ok=True)
    safetensors.torch.save_model(pi0_model, os.path.join(output_path, "model.safetensors"))
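
    # To load the converted weights back later (sketch; assumes the same model_config):
    #   model = openpi.models_pytorch.pi0_pytorch.PI0Pytorch(model_config)
    #   safetensors.torch.load_model(model, os.path.join(output_path, "model.safetensors"))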

    # Copy the checkpoint's assets directory (e.g. normalization stats) next to the converted model, if present.
    assets_source = pathlib.Path(checkpoint_dir).parent / "assets"
    if assets_source.exists():
        assets_dest = pathlib.Path(output_path) / "assets"
        if assets_dest.exists():
            shutil.rmtree(assets_dest)
        shutil.copytree(assets_source, assets_dest)

    # Save a small config.json describing the converted model.
    config_dict = {
        "action_dim": model_config.action_dim,
        "action_horizon": model_config.action_horizon,
        "paligemma_variant": model_config.paligemma_variant,
        "action_expert_variant": model_config.action_expert_variant,
        "precision": precision,
    }
    with open(os.path.join(output_path, "config.json"), "w") as f:
        json.dump(config_dict, f, indent=2)

    print("Model conversion completed successfully!")
    print(f"Model saved to {output_path}")


def main(
    checkpoint_dir: str,
    config_name: str,
    output_path: str | None = None,
    precision: Literal["float32", "bfloat16"] = "bfloat16",
    *,
    inspect_only: bool = False,
):
    """Load a JAX model and optionally convert it to PyTorch.

    Args:
        checkpoint_dir: Path to the JAX checkpoint directory
        config_name: Name of the openpi training config whose model matches the checkpoint (e.g. pi0_droid)
        output_path: Path to save the converted PyTorch model (required for conversion)
        precision: Precision for model conversion
        inspect_only: Only inspect parameter keys, don't convert
    """
    model_config = _config.get_config(config_name).model
    if not isinstance(model_config, openpi.models.pi0_config.Pi0Config):
        raise ValueError(f"Config {config_name} is not a Pi0Config")
    if inspect_only:
        load_jax_model_and_print_keys(checkpoint_dir)
    else:
        if not output_path:
            print("Error: --output_path is required for conversion. Use --inspect_only to only view keys.")
            return
        convert_pi0_checkpoint(checkpoint_dir, precision, output_path, model_config)


if __name__ == "__main__":
    tyro.cli(main)