| """Compute normalization statistics for a config. |
| |
| This script is used to compute the normalization statistics for a given config. It |
| will compute the mean and standard deviation of the data in the dataset and save it |
| to the config assets directory. |
| """ |
|
|
| import numpy as np |
| import tqdm |
| import tyro |
|
|
| import openpi.models.model as _model |
| import openpi.shared.normalize as normalize |
| import openpi.training.config as _config |
| import openpi.training.data_loader as _data_loader |
| import openpi.transforms as transforms |
|
|
|
|
class RemoveStrings(transforms.DataTransformFn):
    """Transform that drops every string-typed entry from a sample dict."""

    def __call__(self, x: dict) -> dict:
        filtered = {}
        for key, value in x.items():
            # Keep the entry only if its (array-coerced) dtype is not a string type;
            # strings cannot contribute to numeric normalization statistics.
            if not np.issubdtype(np.asarray(value).dtype, np.str_):
                filtered[key] = value
        return filtered
|
|
|
|
def create_torch_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    model_config: _model.BaseModelConfig,
    num_workers: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    """Build a torch data loader with repack/data transforms applied.

    Args:
        data_config: Data configuration; must carry a ``repo_id``.
        action_horizon: Number of future actions per sample.
        batch_size: Frames per batch.
        model_config: Model configuration forwarded to dataset creation.
        num_workers: Worker processes for the torch loader.
        max_frames: Optional cap on total frames consumed.

    Returns:
        The data loader and the number of batches it will yield.

    Raises:
        ValueError: If ``data_config.repo_id`` is not set.
    """
    if data_config.repo_id is None:
        raise ValueError("Data config must have a repo_id")

    base_dataset = _data_loader.create_torch_dataset(data_config, action_horizon, model_config)
    transformed = _data_loader.TransformedDataset(
        base_dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Strings can't be folded into numeric stats; drop them up front.
            RemoveStrings(),
        ],
    )

    # When a frame cap is in effect, shuffle so the sampled subset is
    # representative of the whole dataset; otherwise iterate in order.
    capped = max_frames is not None and max_frames < len(transformed)
    frame_budget = max_frames if capped else len(transformed)
    num_batches = frame_budget // batch_size

    loader = _data_loader.TorchDataLoader(
        transformed,
        local_batch_size=batch_size,
        num_workers=num_workers,
        shuffle=capped,
        num_batches=num_batches,
    )
    return loader, num_batches
|
|
|
|
def create_rlds_dataloader(
    data_config: _config.DataConfig,
    action_horizon: int,
    batch_size: int,
    max_frames: int | None = None,
) -> tuple[_data_loader.Dataset, int]:
    """Build an RLDS data loader with repack/data transforms applied.

    Args:
        data_config: Data configuration pointing at an RLDS dataset.
        action_horizon: Number of future actions per sample.
        batch_size: Frames per batch.
        max_frames: Optional cap on total frames consumed.

    Returns:
        The data loader and the number of batches it will yield.
    """
    # NOTE(review): unlike the torch path, this never shuffles even when
    # max_frames caps the dataset — confirm this asymmetry is intentional.
    base_dataset = _data_loader.create_rlds_dataset(data_config, action_horizon, batch_size, shuffle=False)
    transformed = _data_loader.IterableTransformedDataset(
        base_dataset,
        [
            *data_config.repack_transforms.inputs,
            *data_config.data_transforms.inputs,
            # Strings can't be folded into numeric stats; drop them up front.
            RemoveStrings(),
        ],
        is_batched=True,
    )

    if max_frames is not None and max_frames < len(transformed):
        num_batches = max_frames // batch_size
    else:
        num_batches = len(transformed) // batch_size

    loader = _data_loader.RLDSDataLoader(
        transformed,
        num_batches=num_batches,
    )
    return loader, num_batches
|
|
|
|
def main(config_name: str, max_frames: int | None = None):
    """Compute normalization statistics for ``config_name`` and save them.

    Streams the configured dataset, accumulates running statistics for the
    "state" and "actions" keys, and writes the result to the config's
    assets directory.

    Args:
        config_name: Name of the training config to load.
        max_frames: Optional cap on how many frames to consume.
    """
    config = _config.get_config(config_name)
    data_config = config.data.create(config.assets_dirs, config.model)

    # RLDS-backed datasets go through a different loader path than torch datasets.
    if data_config.rlds_data_dir is not None:
        data_loader, num_batches = create_rlds_dataloader(
            data_config, config.model.action_horizon, config.batch_size, max_frames
        )
    else:
        data_loader, num_batches = create_torch_dataloader(
            data_config, config.model.action_horizon, config.batch_size, config.model, config.num_workers, max_frames
        )

    keys = ["state", "actions"]
    running = {key: normalize.RunningStats() for key in keys}

    for batch in tqdm.tqdm(data_loader, total=num_batches, desc="Computing stats"):
        for key in keys:
            running[key].update(np.asarray(batch[key]))

    # Finalize the accumulators into the stats objects that get serialized.
    norm_stats = {key: tracker.get_statistics() for key, tracker in running.items()}

    output_path = config.assets_dirs / data_config.repo_id
    print(f"Writing stats to: {output_path}")
    normalize.save(output_path, norm_stats)
|
|
|
|
| if __name__ == "__main__": |
| tyro.cli(main) |
|
|