|
from itertools import product |
|
from pathlib import Path |
|
|
|
import torch |
|
from omegaconf import OmegaConf |
|
|
|
from lerobot.common.datasets.factory import make_dataset |
|
from lerobot.common.policies.factory import make_policy |
|
from lerobot.common.utils.utils import init_hydra_config |
|
|
|
# Checkpoint to convert (original weights) and the lerobot config used to
# build the target policy.
PATH_TO_ORIGINAL_WEIGHTS = "/tmp/dp.pt"
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
PATH_TO_CONFIG = "/home/alexander/Projects/lerobot/lerobot/configs/default.yaml"
# Output directory for the converted policy (save_pretrained export + config.yaml).
PATH_TO_SAVE_NEW_WEIGHTS = "/tmp/dp"

cfg = init_hydra_config(PATH_TO_CONFIG)

# Build the target policy; dataset statistics are taken from the dataset
# described by the same config.
policy = make_policy(cfg, dataset_stats=make_dataset(cfg).stats)

# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only machine. load_state_dict copies values into the policy's existing
# parameters, so the policy's own device placement is unaffected.
# NOTE(review): torch.load unpickles the file -- only run this on trusted
# checkpoints (consider weights_only=True where the torch version supports it).
state_dict = torch.load(PATH_TO_ORIGINAL_WEIGHTS, map_location="cpu")
|
|
|
|
|
|
|
# Prefixes of checkpoint entries that are discarded outright, before any renaming.
# NOTE(review): "normalizer." is presumably dropped because the new policy gets
# its normalization stats via make_policy(..., dataset_stats=...) -- confirm.
start_removals = ["normalizer.", "obs_encoder.obs_nets.rgb.backbone.nets.0.nets.0"]

# Gather the doomed keys first: deleting from a dict while iterating it raises.
removal_prefixes = tuple(start_removals)
keys_to_drop = [key for key in state_dict if key.startswith(removal_prefixes)]
for key in keys_to_drop:
    del state_dict[key]
|
|
|
|
|
|
|
|
|
# Ordered (old_prefix, new_prefix) rename rules, applied first to last.
# Order matters: every specific "model.*" rule must run before the catch-all
# "model." -> "unet." rule appended at the very end.
start_replacements = [
    # Image encoder pieces move under "rgb_encoder".
    ("obs_encoder.obs_nets.image.backbone.nets", "rgb_encoder.backbone"),
    ("obs_encoder.obs_nets.image.pool", "rgb_encoder.pool"),
    ("obs_encoder.obs_nets.image.nets.3", "rgb_encoder.out"),
]

# Submodule index 2 of up/down stages 0-1 loses its ".conv" level
# (up stages first, then down stages).
start_replacements += [
    (f"model.{side}_modules.{i}.2.conv.", f"model.{side}_modules.{i}.2.")
    for side in ("up", "down")
    for i in range(2)
]

# "blocks.{k}" inside residual blocks becomes "conv{k + 1}": mid stages first,
# then down stages, then up stages.
start_replacements += [
    (f"model.mid_modules.{i}.blocks.{k}.", f"model.mid_modules.{i}.conv{k + 1}.")
    for i in range(3)
    for k in range(2)
]
start_replacements += [
    (
        f"model.{side}_modules.{i}.{j}.blocks.{k}.",
        f"model.{side}_modules.{i}.{j}.conv{k + 1}.",
    )
    for side in ("down", "up")
    for i, j, k in product(range(3), range(2), range(2))
]

# Catch-all: whatever is still under "model." moves under "unet.".
start_replacements.append(("model.", "unet."))
|
|
|
# Apply each rename rule in turn. Matching keys are collected before the dict
# is mutated; each one is re-inserted under its new name with the same value.
for old_prefix, new_prefix in start_replacements:
    matched_keys = [key for key in state_dict if key.startswith(old_prefix)]
    for old_key in matched_keys:
        new_key = new_prefix + old_key.removeprefix(old_prefix)
        state_dict[new_key] = state_dict[old_key]
        del state_dict[old_key]
|
|
|
# strict=False: key mismatches are returned rather than raised, so the script
# can compare them against an allowlist and report them itself.
missing_keys, unexpected_key_list = policy.diffusion.load_state_dict(state_dict, strict=False)
# Set form so the comparison with the allowlist ignores ordering.
unexpected_keys = set(unexpected_key_list)
|
# Source-checkpoint keys that the rename rules leave unmapped and that the
# new policy is therefore expected to report as "unexpected".
# Written as a plain set literal (no eval of a string) so linters and editors
# can see the keys.
# NOTE(review): the image-encoder entries follow conv/bn/downsample
# residual-block naming (presumably a ResNet-18-like trunk under nets.0 and a
# pooling head with pos_x/pos_y/temperature under nets.1) -- confirm.
allowed_unexpected_keys = {
    # Placeholder parameters.
    "_dummy_variable",
    "mask_generator._dummy_variable",
    # Trunk stem.
    "obs_encoder.obs_nets.image.nets.0.nets.0.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.1.weight",
    # Stage nets.0.nets.4.
    "obs_encoder.obs_nets.image.nets.0.nets.4.0.bn1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.4.0.bn1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.4.0.bn2.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.4.0.bn2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.4.0.conv1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.4.0.conv2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.4.1.bn1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.4.1.bn1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.4.1.bn2.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.4.1.bn2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.4.1.conv1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.4.1.conv2.weight",
    # Stage nets.0.nets.5.
    "obs_encoder.obs_nets.image.nets.0.nets.5.0.bn1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.5.0.bn1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.5.0.bn2.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.5.0.bn2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.5.0.conv1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.5.0.conv2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.5.0.downsample.0.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.5.0.downsample.1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.5.0.downsample.1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.5.1.bn1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.5.1.bn1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.5.1.bn2.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.5.1.bn2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.5.1.conv1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.5.1.conv2.weight",
    # Stage nets.0.nets.6.
    "obs_encoder.obs_nets.image.nets.0.nets.6.0.bn1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.6.0.bn1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.6.0.bn2.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.6.0.bn2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.6.0.conv1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.6.0.conv2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.6.0.downsample.0.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.6.0.downsample.1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.6.0.downsample.1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.6.1.bn1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.6.1.bn1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.6.1.bn2.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.6.1.bn2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.6.1.conv1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.6.1.conv2.weight",
    # Stage nets.0.nets.7.
    "obs_encoder.obs_nets.image.nets.0.nets.7.0.bn1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.7.0.bn1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.7.0.bn2.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.7.0.bn2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.7.0.conv1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.7.0.conv2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.7.0.downsample.0.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.7.0.downsample.1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.7.0.downsample.1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.7.1.bn1.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.7.1.bn1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.7.1.bn2.bias",
    "obs_encoder.obs_nets.image.nets.0.nets.7.1.bn2.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.7.1.conv1.weight",
    "obs_encoder.obs_nets.image.nets.0.nets.7.1.conv2.weight",
    # Head under nets.1.
    "obs_encoder.obs_nets.image.nets.1.nets.bias",
    "obs_encoder.obs_nets.image.nets.1.nets.weight",
    "obs_encoder.obs_nets.image.nets.1.pos_x",
    "obs_encoder.obs_nets.image.nets.1.pos_y",
    "obs_encoder.obs_nets.image.nets.1.temperature",
}
|
# Report every mismatch between the converted checkpoint and the policy, then
# abort before anything is saved.
has_missing = bool(missing_keys)
has_unexpected = unexpected_keys != allowed_unexpected_keys

if has_missing:
    print("MISSING KEYS")
    print(missing_keys)

if has_unexpected:
    print("UNEXPECTED KEYS")
    print(unexpected_keys)

if has_missing or has_unexpected:
    print("Failed due to mismatch in state dicts.")
    # Exit with status 1 so shells/CI see the failure. SystemExit needs no
    # import, unlike the site-module exit(), which is meant for the
    # interactive interpreter, is absent under `python -S`, and exits with
    # status 0 when called with no argument.
    raise SystemExit(1)
|
|
|
# Persist the converted policy three ways: a raw state_dict dump, the
# save_pretrained() export, and the config that built the policy
# (stored in the export directory).
converted_state = policy.state_dict()
torch.save(converted_state, "/tmp/policy.pt")

policy.save_pretrained(PATH_TO_SAVE_NEW_WEIGHTS)

output_dir = Path(PATH_TO_SAVE_NEW_WEIGHTS)
config_path = output_dir / "config.yaml"
OmegaConf.save(cfg, config_path)