| | import abc |
| | from collections.abc import Sequence |
| | import dataclasses |
| | import enum |
| | import logging |
| | import pathlib |
| | from typing import Generic, TypeVar |
| |
|
| | import augmax |
| | from flax import nnx |
| | from flax import struct |
| | from flax import traverse_util |
| | import jax |
| | import jax.numpy as jnp |
| | import numpy as np |
| | import orbax.checkpoint as ocp |
| |
|
| | from openpi.shared import image_tools |
| | import openpi.shared.array_typing as at |
| |
|
# Shared logger for the openpi package (looked up by name, so every module asking for "openpi" gets this instance).
logger = logging.getLogger(name="openpi")
| |
|
# Type variable for the array leaves of `Observation` / `Actions`: either concrete arrays (`at.Array`) or
# `jax.ShapeDtypeStruct` placeholders, such as the values returned by `BaseModelConfig.inputs_spec`.
ArrayT = TypeVar("ArrayT", at.Array, jax.ShapeDtypeStruct)
| |
|
| |
|
class ModelType(enum.Enum):
    """Enumerates the model types this module supports.

    Each member's value is its lowercase string identifier, so `ModelType("pi0")` looks a member up by name string.
    """

    # String identifier "pi0".
    PI0 = "pi0"
    # String identifier "pi0_fast".
    PI0_FAST = "pi0_fast"
| |
|
| |
|
| | |
# Default camera keys expected in `Observation.images`; used as the default `image_keys` of
# `preprocess_observation`, where keys containing "wrist" skip the geometric train-time augmentations.
IMAGE_KEYS: tuple[str, ...] = ("base_0_rgb", "left_wrist_0_rgb", "right_wrist_0_rgb")

# Default (height, width) that `preprocess_observation` resizes (with padding) images to.
IMAGE_RESOLUTION: tuple[int, int] = (224, 224)
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
@at.typecheck
@struct.dataclass
class Observation(Generic[ArrayT]):
    """Holds observations, i.e., inputs to the model.

    See `Observation.from_dict` to see the expected dictionary form. This is the format
    that should be produced by the data transforms.
    """

    # Images keyed by camera name, each of shape (*b, h, w, c). `from_dict` converts uint8 images to float32
    # in [-1, 1]; float images are passed through unchanged.
    images: dict[str, at.Float[ArrayT, "*b h w c"]]
    # Per-camera boolean masks of shape (*b,), keyed like `images`.
    image_masks: dict[str, at.Bool[ArrayT, "*b"]]
    # Float state vector of shape (*b, s).
    state: at.Float[ArrayT, "*b s"]

    # Optional tokenized prompt of shape (*b, l) and its boolean mask; `from_dict` requires both or neither.
    tokenized_prompt: at.Int[ArrayT, "*b l"] | None = None
    tokenized_prompt_mask: at.Bool[ArrayT, "*b l"] | None = None

    # Optional per-token masks of shape (*b, l).
    # NOTE(review): presumably only populated for token-based model variants — confirm against the data transforms.
    token_ar_mask: at.Int[ArrayT, "*b l"] | None = None
    token_loss_mask: at.Bool[ArrayT, "*b l"] | None = None

    @classmethod
    def from_dict(cls, data: at.PyTree[ArrayT]) -> "Observation[ArrayT]":
        """This method defines the mapping between unstructured data (i.e., nested dict) to the structured Observation format.

        Expected keys: "image" (camera name -> image), "image_mask" (camera name -> mask) and "state"; optionally
        "tokenized_prompt" together with "tokenized_prompt_mask", plus "token_ar_mask" and "token_loss_mask".
        uint8 images are rescaled to float32 in [-1, 1]. The input `data` is not modified.

        Raises:
            ValueError: if exactly one of "tokenized_prompt" / "tokenized_prompt_mask" is present.
        """
        if ("tokenized_prompt" in data) != ("tokenized_prompt_mask" in data):
            raise ValueError("tokenized_prompt and tokenized_prompt_mask must be provided together.")

        # Build a fresh dict instead of writing back into data["image"]: in-place assignment would silently
        # mutate the caller's input (and convert its images) as a side effect of constructing an Observation.
        images = {}
        for key, image in data["image"].items():
            if image.dtype == np.uint8:
                # [0, 255] -> [0, 1] -> [-1, 1]
                image = image.astype(np.float32) / 255.0 * 2.0 - 1.0
            images[key] = image

        return cls(
            images=images,
            image_masks=data["image_mask"],
            state=data["state"],
            tokenized_prompt=data.get("tokenized_prompt"),
            tokenized_prompt_mask=data.get("tokenized_prompt_mask"),
            token_ar_mask=data.get("token_ar_mask"),
            token_loss_mask=data.get("token_loss_mask"),
        )

    def to_dict(self) -> at.PyTree[ArrayT]:
        """Convert the Observation to a nested dict.

        Field names are kept as keys, except `images` / `image_masks`, which are renamed to the "image" /
        "image_mask" keys read by `from_dict`.
        """
        result = dataclasses.asdict(self)
        result["image"] = result.pop("images")
        result["image_mask"] = result.pop("image_masks")
        return result
| |
|
| |
|
| | |
| | |
# Type annotation for an action chunk: a float array of shape (*b, ah, ad).
# NOTE(review): presumably ah = action_horizon and ad = action_dim of the model config — confirm.
Actions = at.Float[ArrayT, "*b ah ad"]
| |
|
| |
|
def _augment_image(rng: at.KeyArrayLike, image: at.Array, *, geometric: bool) -> at.Array:
    """Randomly augment a batch of images (assumed to be in [-1, 1]) and return them in the same range."""
    # augmax operates on pixel values in [0, 1].
    image = image / 2.0 + 0.5

    transforms = []
    if geometric:
        # Random crop to 95% of the size, resized back to the original size, plus a small random rotation.
        height, width = image.shape[1:3]
        transforms += [
            augmax.RandomCrop(int(width * 0.95), int(height * 0.95)),
            augmax.Resize(width, height),
            augmax.Rotate((-5, 5)),
        ]
    transforms.append(augmax.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5))

    # One key per leading (batch) element so each example gets its own random augmentation.
    sub_rngs = jax.random.split(rng, image.shape[0])
    image = jax.vmap(augmax.Chain(*transforms))(sub_rngs, image)

    # Back to [-1, 1].
    return image * 2.0 - 1.0


def preprocess_observation(
    rng: at.KeyArrayLike | None,
    observation: Observation,
    *,
    train: bool = False,
    image_keys: Sequence[str] = IMAGE_KEYS,
    image_resolution: tuple[int, int] = IMAGE_RESOLUTION,
) -> Observation:
    """Preprocess the observations by performing image augmentations (if train=True), resizing (if necessary), and
    filling in a default image mask (if necessary).

    Args:
        rng: PRNG key used for augmentation. May be None only when `train` is False.
        observation: Observation to preprocess; must contain every key in `image_keys`.
        train: If True, apply random augmentation: crop/resize/rotate for cameras whose key does not contain
            "wrist", and color jitter for every camera.
        image_keys: Camera keys to keep. Images under other keys are dropped from the result.
        image_resolution: Target (height, width); images of any other size are resized with padding.

    Returns:
        A new Observation containing only the `image_keys` images, one mask per kept image (all True when the
        input has no mask for that key), and the remaining fields passed through unchanged.

    Raises:
        ValueError: if `observation.images` lacks one of `image_keys`, or if `train` is True and `rng` is None.
    """
    if not set(image_keys).issubset(observation.images):
        raise ValueError(f"images dict missing keys: expected {image_keys}, got {list(observation.images)}")
    if train and rng is None:
        # Fail early with a clear message instead of an opaque error from jax.random.split(None, ...).
        raise ValueError("An rng key is required when train=True (it drives the image augmentations).")

    # Normalize to a tuple so a list such as [224, 224] compares equal to an image shape tuple.
    target_resolution = tuple(image_resolution)
    batch_shape = observation.state.shape[:-1]

    out_images = {}
    for key in image_keys:
        image = observation.images[key]
        if tuple(image.shape[1:3]) != target_resolution:
            logger.info("Resizing image %s from %s to %s", key, image.shape[1:3], target_resolution)
            image = image_tools.resize_with_pad(image, *target_resolution)

        if train:
            # NOTE(review): the same `rng` is reused for every camera key, so cameras sharing a transform chain
            # (e.g. the two wrist cameras) presumably receive identical random jitter per example. Confirm intent.
            image = _augment_image(rng, image, geometric="wrist" not in key)

        out_images[key] = image

    # Cameras without a provided mask default to an all-True mask over the batch dimensions.
    out_masks = {}
    for key in out_images:
        if key in observation.image_masks:
            out_masks[key] = jnp.asarray(observation.image_masks[key])
        else:
            out_masks[key] = jnp.ones(batch_shape, dtype=jnp.bool)

    return Observation(
        images=out_images,
        image_masks=out_masks,
        state=observation.state,
        tokenized_prompt=observation.tokenized_prompt,
        tokenized_prompt_mask=observation.tokenized_prompt_mask,
        token_ar_mask=observation.token_ar_mask,
        token_loss_mask=observation.token_loss_mask,
    )
| |
|
| |
|
@dataclasses.dataclass(frozen=True)
class BaseModelConfig(abc.ABC):
    """Abstract, immutable configuration shared by every model.

    Subclasses provide `model_type`, `create` and `inputs_spec`. This base class then supplies:
    - `load`, which builds a model from existing params;
    - `fake_obs` / `fake_act`, which return all-ones inputs shaped like `inputs_spec`.
    """

    # Action dimension (presumably the `ad` axis of `Actions`).
    action_dim: int
    # Action horizon (presumably the `ah` axis of `Actions`).
    action_horizon: int
    # Maximum token length (presumably the `l` axis of the tokenized prompt).
    max_token_len: int

    @property
    @abc.abstractmethod
    def model_type(self) -> ModelType:
        """The `ModelType` this configuration describes."""

    @abc.abstractmethod
    def create(self, rng: at.KeyArrayLike) -> "BaseModel":
        """Build a new model, initializing its parameters from `rng`."""

    @abc.abstractmethod
    def inputs_spec(self, *, batch_size: int = 1) -> tuple[Observation, Actions]:
        """Return (observation_spec, action_spec) for `batch_size`; values are jax.ShapeDtypeStruct."""

    def load(self, params: at.Params, *, remove_extra_params: bool = True) -> "BaseModel":
        """Build a model whose state is taken from `params` rather than freshly initialized.

        The model structure comes from running `create` under `nnx.eval_shape` with a fixed key.

        If `remove_extra_params` is True, `params` is first intersected with the model's state tree via
        `ocp.transform_utils.intersect_trees`. The (possibly reduced) params must then match that tree's
        structure and shapes; dtypes are not checked.
        """
        graphdef, abstract_state = nnx.split(nnx.eval_shape(self.create, jax.random.key(0)))
        if remove_extra_params:
            params = ocp.transform_utils.intersect_trees(abstract_state.to_pure_dict(), params)
        at.check_pytree_equality(
            expected=abstract_state.to_pure_dict(),
            got=params,
            check_shapes=True,
            check_dtypes=False,
        )
        abstract_state.replace_by_pure_dict(params)
        return nnx.merge(graphdef, abstract_state)

    def fake_obs(self, batch_size: int = 1) -> Observation:
        """Return an all-ones Observation matching the observation part of `inputs_spec(batch_size=...)`."""
        obs_spec = self.inputs_spec(batch_size=batch_size)[0]
        return jax.tree.map(lambda leaf: jnp.ones(leaf.shape, leaf.dtype), obs_spec)

    def fake_act(self, batch_size: int = 1) -> Actions:
        """Return all-ones actions matching the action part of `inputs_spec(batch_size=...)`."""
        act_spec = self.inputs_spec(batch_size=batch_size)[1]
        return jax.tree.map(lambda leaf: jnp.ones(leaf.shape, leaf.dtype), act_spec)
| |
|
| |
|
@dataclasses.dataclass
class BaseModel(nnx.Module, abc.ABC):
    """Abstract base for concrete model implementations.

    Subclasses must call `super().__init__()` so that the shared attributes `action_dim`, `action_horizon`
    and `max_token_len` are set. They must also implement `sample_actions` and `compute_loss`.
    """

    action_dim: int
    action_horizon: int
    max_token_len: int

    @abc.abstractmethod
    def sample_actions(self, rng: at.KeyArrayLike, observation: Observation) -> Actions:
        """Sample actions for `observation`, using `rng` as the source of randomness."""

    @abc.abstractmethod
    def compute_loss(
        self,
        rng: at.KeyArrayLike,
        observation: Observation,
        actions: Actions,
        *,
        train: bool = False,
    ) -> at.Float[at.Array, "*b ah"]:
        """Compute the loss of `actions` given `observation`; the result has shape (*b, ah)."""
| |
|
| |
|
def restore_params(
    params_path: pathlib.Path | str,
    *,
    restore_type: type[np.ndarray] | type[jax.Array] = jax.Array,
    dtype: jnp.dtype | None = None,
    sharding: jax.sharding.Sharding | None = None,
) -> at.Params:
    """Load the unstructured params PyTree stored under the "params" entry of a checkpoint.

    Works for checkpoints written by `save_state` during openpi training (see `training/checkpoints.py`) as
    well as for the pre-trained checkpoints released with openpi.

    Args:
        params_path: Local path to the checkpoint directory.
        restore_type: Array type to restore into; pass `np.ndarray` to get numpy arrays.
        dtype: If given, every param is restored with this dtype; otherwise the checkpoint's dtype is kept.
        sharding: Sharding for the restored params. When omitted and `restore_type` is `jax.Array`, params are
            replicated across all devices.

    Returns:
        The restored params as a nested dict.

    Raises:
        FileNotFoundError: if `params_path` does not exist.
    """
    checkpoint_dir = pathlib.Path(params_path).resolve()
    if not checkpoint_dir.exists():
        raise FileNotFoundError(f"Model params not found at: {checkpoint_dir}")

    # Default for jax.Array restores: a 1-D mesh over jax.devices() with an empty PartitionSpec (full replication).
    if sharding is None and restore_type is jax.Array:
        device_mesh = jax.sharding.Mesh(jax.devices(), ("x",))
        sharding = jax.sharding.NamedSharding(device_mesh, jax.sharding.PartitionSpec())

    with ocp.PyTreeCheckpointer() as checkpointer:
        # Request only the "params" subtree, using the checkpoint's own metadata as the restore target.
        target = {"params": checkpointer.metadata(checkpoint_dir)["params"]}
        per_leaf_args = jax.tree.map(
            lambda _: ocp.ArrayRestoreArgs(sharding=sharding, restore_type=restore_type, dtype=dtype),
            target,
        )
        restored = checkpointer.restore(
            checkpoint_dir,
            ocp.args.PyTreeRestore(item=target, restore_args=per_leaf_args),
        )
        params = restored["params"]

    # When every leaf path ends in a "value" key, drop that trailing key so a plain nested dict is returned.
    # NOTE(review): this suffix is presumably added by nnx.State serialization in training checkpoints — confirm.
    flat = traverse_util.flatten_dict(params)
    if all(key_path[-1] == "value" for key_path in flat):
        flat = {key_path[:-1]: leaf for key_path, leaf in flat.items()}
    return traverse_util.unflatten_dict(flat)
| |
|