Spaces:
Running
Running
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import json | |
import pickle | |
from pathlib import Path | |
import torch | |
from lerobot.common.datasets.lerobot_dataset import LeRobotDatasetMetadata | |
from lerobot.common.policies.factory import make_policy | |
from lerobot.configs.policies import PreTrainedConfig | |
def display(tensor: torch.Tensor): | |
if tensor.dtype == torch.bool: | |
tensor = tensor.float() | |
print(f"Shape: {tensor.shape}") | |
print(f"Mean: {tensor.mean().item()}") | |
print(f"Std: {tensor.std().item()}") | |
print(f"Min: {tensor.min().item()}") | |
print(f"Max: {tensor.max().item()}") | |
def main(): | |
num_motors = 14 | |
device = "cuda" | |
# model_name = "pi0_aloha_towel" | |
model_name = "pi0_aloha_sim" | |
if model_name == "pi0_aloha_towel": | |
dataset_repo_id = "lerobot/aloha_static_towel" | |
else: | |
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human" | |
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch" | |
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}" | |
save_dir = Path(f"../openpi/data/{model_name}/save") | |
with open(save_dir / "example.pkl", "rb") as f: | |
example = pickle.load(f) | |
with open(save_dir / "outputs.pkl", "rb") as f: | |
outputs = pickle.load(f) | |
with open(save_dir / "noise.pkl", "rb") as f: | |
noise = pickle.load(f) | |
with open(ckpt_jax_dir / "assets/norm_stats.json") as f: | |
norm_stats = json.load(f) | |
# Override stats | |
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id) | |
dataset_meta.stats["observation.state"]["mean"] = torch.tensor( | |
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32 | |
) | |
dataset_meta.stats["observation.state"]["std"] = torch.tensor( | |
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32 | |
) | |
# Create LeRobot batch from Jax | |
batch = {} | |
for cam_key, uint_chw_array in example["images"].items(): | |
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0 | |
batch["observation.state"] = torch.from_numpy(example["state"]) | |
batch["action"] = torch.from_numpy(outputs["actions"]) | |
batch["task"] = example["prompt"] | |
if model_name == "pi0_aloha_towel": | |
del batch["observation.images.cam_low"] | |
elif model_name == "pi0_aloha_sim": | |
batch["observation.images.top"] = batch["observation.images.cam_high"] | |
del batch["observation.images.cam_high"] | |
# Batchify | |
for key in batch: | |
if isinstance(batch[key], torch.Tensor): | |
batch[key] = batch[key].unsqueeze(0) | |
elif isinstance(batch[key], str): | |
batch[key] = [batch[key]] | |
else: | |
raise ValueError(f"{key}, {batch[key]}") | |
# To device | |
for k in batch: | |
if isinstance(batch[k], torch.Tensor): | |
batch[k] = batch[k].to(device=device, dtype=torch.float32) | |
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32) | |
from lerobot.common import policies # noqa | |
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir) | |
cfg.pretrained_path = ckpt_torch_dir | |
policy = make_policy(cfg, dataset_meta) | |
# loss_dict = policy.forward(batch, noise=noise, time=time_beta) | |
# loss_dict["loss"].backward() | |
# print("losses") | |
# display(loss_dict["losses_after_forward"]) | |
# print("pi_losses") | |
# display(pi_losses) | |
actions = [] | |
for _ in range(50): | |
action = policy.select_action(batch, noise=noise) | |
actions.append(action) | |
actions = torch.stack(actions, dim=1) | |
pi_actions = batch["action"] | |
print("actions") | |
display(actions) | |
print() | |
print("pi_actions") | |
display(pi_actions) | |
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2)) | |
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2)) | |
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2)) | |
if __name__ == "__main__": | |
main() | |