| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional |
|
|
| import numpy as np |
| import torch |
| from pydantic import Field |
|
|
| from ..schema import DatasetMetadata, StateActionMetadata |
| from .base import InvertibleModalityTransform |
|
|
|
|
class ConcatTransform(InvertibleModalityTransform):
    """Concatenate per-key modalities into single ``video``/``state``/``action`` entries.

    ``apply`` pops the keys listed in the ``*_concat_order`` fields from the data
    dict and replaces them with one concatenated entry per modality.
    ``unapply`` splits ``action`` (and ``state``, when present) back into the
    original keys using the per-key widths stored in ``action_dims`` /
    ``state_dims``; video is not restored by ``unapply``.
    """

    # Declared but never read by this class.
    apply_to: list[str] = Field(
        default_factory=list, description="Not used in this transform, kept for compatibility."
    )

    # Order in which video keys are joined along the new view axis.
    video_concat_order: list[str] = Field(
        ...,
        description="Concatenation order for each video modality. "
        "Format: ['video.ego_view_pad_res224_freq20', ...]",
    )

    # Order in which state keys are joined along the last dimension.
    state_concat_order: Optional[list[str]] = Field(
        default=None,
        description="Concatenation order for each state modality. "
        "Format: ['state.position', 'state.velocity', ...].",
    )

    # Order in which action keys are joined along the last dimension.
    action_concat_order: Optional[list[str]] = Field(
        default=None,
        description="Concatenation order for each action modality. "
        "Format: ['action.position', 'action.velocity', ...].",
    )

    # Per-key widths; filled in by ``set_metadata`` from the dataset metadata.
    action_dims: dict[str, int] = Field(
        default_factory=dict,
        description="The dimensions of the action keys.",
    )
    state_dims: dict[str, int] = Field(
        default_factory=dict,
        description="The dimensions of the state keys.",
    )

    def model_dump(self, *args, **kwargs):
        """Dump the model, restricting JSON-mode output to the configuration fields.

        In ``mode="json"`` only ``apply_to`` and the three ``*_concat_order``
        fields are emitted; ``action_dims`` and ``state_dims`` are left out
        (presumably because ``set_metadata`` recomputes them -- confirm). Any
        caller-supplied ``include`` is ignored in that mode. In every other
        mode the caller's ``include`` (default ``None``) is forwarded as is.
        """
        # Always remove ``include`` from kwargs so it can never reach the parent
        # call twice (once explicitly, once via **kwargs), which is a TypeError.
        caller_include = kwargs.pop("include", None)
        if kwargs.get("mode", "python") == "json":
            include = {
                "apply_to",
                "video_concat_order",
                "state_concat_order",
                "action_concat_order",
            }
        else:
            include = caller_include

        return super().model_dump(*args, include=include, **kwargs)

    def apply(self, data: dict) -> dict:
        """Concatenate the configured video/state/action keys of ``data`` in place.

        Keys are grouped by the part before the single ``.`` in their name. Keys
        that do not split into exactly two parts are grouped as ``"language"``
        when they contain ``"annotation"`` and as ``"others"`` otherwise; those
        groups are not modified here.

        Args:
            data: Mapping of ``"<modality>.<name>"`` keys to arrays/tensors.
                Mutated in place.

        Returns:
            The same ``data`` dict, with the consumed keys replaced by
            ``"video"``, ``"state"`` and/or ``"action"``.

        Raises:
            AssertionError: If a concat order is missing or inconsistent with
                the keys present, or a state/action width does not match.
        """
        grouped_keys: dict[str, list] = {}
        for key in data:
            try:
                modality, _ = key.split(".")
            except (AttributeError, ValueError):
                # Not of the form "<modality>.<name>" (or not a string at all).
                modality = "language" if "annotation" in key else "others"
            grouped_keys.setdefault(modality, []).append(key)

        if "video" in grouped_keys:
            video_keys = grouped_keys["video"]
            assert self.video_concat_order is not None, f"{self.video_concat_order=}, {video_keys=}"
            assert all(
                item in video_keys for item in self.video_concat_order
            ), f"keys in video_concat_order are misspecified, \n{video_keys=}, \n{self.video_concat_order=}"

            # Insert a new axis fourth-from-last in each video and join along it.
            # NOTE(review): presumably frames are (..., H, W, C), making this a
            # per-view axis -- confirm against the upstream video transform.
            per_view = [
                np.expand_dims(data.pop(video_key), axis=-4)
                for video_key in self.video_concat_order
            ]
            data["video"] = np.concatenate(per_view, axis=-4)

        if "state" in grouped_keys:
            state_keys = grouped_keys["state"]
            assert self.state_concat_order is not None, f"{self.state_concat_order=}"
            assert all(
                item in state_keys for item in self.state_concat_order
            ), f"keys in state_concat_order are misspecified, \n{state_keys=}, \n{self.state_concat_order=}"

            for key in self.state_concat_order:
                # Accepted last-dim widths: the metadata width, 6 for rotation
                # keys, and twice the metadata width.
                # NOTE(review): the 6 and 2x cases presumably cover upstream
                # rotation / sin-cos re-encodings -- confirm.
                target_shapes = [self.state_dims[key]]
                if self.is_rotation_key(key):
                    target_shapes.append(6)
                target_shapes.append(self.state_dims[key] * 2)
                assert (
                    data[key].shape[-1] in target_shapes
                ), f"State dim mismatch for {key=}, {data[key].shape[-1]=}, {target_shapes=}"

            data["state"] = torch.cat(
                [data.pop(key) for key in self.state_concat_order], dim=-1
            )

        if "action" in grouped_keys:
            action_keys = grouped_keys["action"]
            assert self.action_concat_order is not None, f"{self.action_concat_order=}"
            # Unlike video/state, every action key present must be listed.
            assert set(self.action_concat_order) == set(
                action_keys
            ), f"{set(self.action_concat_order)=}, {set(action_keys)=}"

            for key in self.action_concat_order:
                # Exact match only: unapply() slices the concatenated tensor by
                # action_dims, so any other width would split at wrong offsets.
                assert (
                    self.action_dims[key] == data[key].shape[-1]
                ), f"Action dim mismatch for {key=}, {self.action_dims[key]=}, {data[key].shape[-1]=}"

            data["action"] = torch.cat(
                [data.pop(key) for key in self.action_concat_order], dim=-1
            )

        return data

    def unapply(self, data: dict) -> dict:
        """Split ``action`` (and ``state``, if present) back into per-key entries.

        Each key receives a slice of the last dimension whose width is taken
        from ``action_dims`` / ``state_dims``, in concat order.

        Args:
            data: Dict containing ``"action"`` and optionally ``"state"``.
                Mutated in place.

        Returns:
            The same ``data`` dict with ``"action"``/``"state"`` replaced by the
            individual keys.

        Raises:
            AssertionError: If ``"action"`` is missing or a required concat
                order is ``None``.
            ValueError: If an action key has no entry in ``action_dims``.
        """
        assert "action" in data, f"{data.keys()=}"
        assert self.action_concat_order is not None, f"{self.action_concat_order=}"

        action_tensor = data.pop("action")
        start_dim = 0
        for key in self.action_concat_order:
            if key not in self.action_dims:
                raise ValueError(f"Action dim {key} not found in action_dims.")
            end_dim = start_dim + self.action_dims[key]
            data[key] = action_tensor[..., start_dim:end_dim]
            start_dim = end_dim

        if "state" in data:
            assert self.state_concat_order is not None, f"{self.state_concat_order=}"
            state_tensor = data.pop("state")
            start_dim = 0
            for key in self.state_concat_order:
                end_dim = start_dim + self.state_dims[key]
                data[key] = state_tensor[..., start_dim:end_dim]
                start_dim = end_dim

        return data

    def __call__(self, data: dict) -> dict:
        """Apply the transform; equivalent to ``self.apply(data)``."""
        return self.apply(data)

    def get_modality_metadata(self, key: str) -> StateActionMetadata:
        """Look up the metadata entry for a ``"<modality>.<subkey>"`` key.

        Raises:
            AssertionError: If the dataset metadata is not set, the subkey is
                unknown, or the entry is not a ``StateActionMetadata``.
            ValueError: If ``key`` does not contain exactly one ``"."``.
        """
        modality, subkey = key.split(".")
        assert self.dataset_metadata is not None, "Metadata not set"
        modality_config = getattr(self.dataset_metadata.modalities, modality)
        assert subkey in modality_config, f"{subkey=} not found in {modality_config=}"
        assert isinstance(
            modality_config[subkey], StateActionMetadata
        ), f"Expected {StateActionMetadata} for {subkey=}, got {type(modality_config[subkey])=}"
        return modality_config[subkey]

    def get_state_action_dims(self, key: str) -> int:
        """Get the dimension of a state or action key from the dataset metadata.

        The metadata shape must be one-dimensional; its single entry is returned.
        """
        modality_config = self.get_modality_metadata(key)
        shape = modality_config.shape
        assert len(shape) == 1, f"{shape=}"
        return shape[0]

    def is_rotation_key(self, key: str) -> bool:
        """Return True when the key's metadata has a non-``None`` ``rotation_type``."""
        modality_config = self.get_modality_metadata(key)
        return modality_config.rotation_type is not None

    def set_metadata(self, dataset_metadata: DatasetMetadata):
        """Set the metadata and compute the dimensions of the state and action keys."""
        super().set_metadata(dataset_metadata)
        # Cache the width of every configured key for apply()/unapply().
        if self.action_concat_order is not None:
            for key in self.action_concat_order:
                self.action_dims[key] = self.get_state_action_dims(key)
        if self.state_concat_order is not None:
            for key in self.state_concat_order:
                self.state_dims[key] = self.get_state_action_dims(key)
|
|