|
|
import collections |
|
|
import collections.abc |
|
|
import re |
|
|
import warnings |
|
|
from abc import abstractmethod |
|
|
from functools import cached_property |
|
|
from typing import Dict, List, Optional, Sequence, Tuple, TypeVar |
|
|
|
|
|
import numpy as np |
|
|
import PIL.Image |
|
|
import roma |
|
|
import torch |
|
|
import torchvision.transforms.v2 |
|
|
import transformers |
|
|
import yaml |
|
|
|
|
|
from .common_spear import ( |
|
|
Configurable, |
|
|
FlowInput, |
|
|
Normalization, |
|
|
ResizeMode, |
|
|
RoboticsControlPlan, |
|
|
RoboticsFlowInput, |
|
|
RoboticsInput, |
|
|
RoboticsOutput, |
|
|
RoboticsTarget, |
|
|
RotationFormat, |
|
|
expand_dims, |
|
|
is_quaternion, |
|
|
is_rotmat, |
|
|
is_rotmat_3x3, |
|
|
is_rotmat_9, |
|
|
quaternion_half_cover, |
|
|
rotmat_as_3x3, |
|
|
rotmat_as_9, |
|
|
) |
|
|
from .configuration_spear import ( |
|
|
ControlDataIOConfig, |
|
|
ImageSizeConfig, |
|
|
PaliGemmaProcessorConfig, |
|
|
) |
|
|
|
|
|
|
|
|
class VLMProcessor(Configurable): |
|
|
@abstractmethod |
|
|
def preprocess_inputs( |
|
|
self, chat: List[str], images: Dict[str, List[PIL.Image.Image]] |
|
|
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: ... |
|
|
|
|
|
@property |
|
|
@abstractmethod |
|
|
def tokenizer(self) -> transformers.PreTrainedTokenizerBase: |
|
|
pass |
|
|
|
|
|
@property |
|
|
@abstractmethod |
|
|
def image_sizes(self) -> Dict[str, ImageSizeConfig]: |
|
|
pass |
|
|
|
|
|
|
|
|
class EmptyTokenizer(Configurable): |
|
|
""" |
|
|
Takes the LLM hidden states from `llm_layer_indices` and concatenates them to produce the |
|
|
desired result. Includes the hidden states for the image tokens. |
|
|
""" |
|
|
|
|
|
def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None: |
|
|
super().__init__(config) |
|
|
self.tokenizer = tokenizer |
|
|
|
|
|
def __call__(self, *_) -> str: |
|
|
return "" |
|
|
|
|
|
|
|
|
def np_unique( |
|
|
data: np.ndarray, |
|
|
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: |
|
|
""" |
|
|
Compute unique elements in data and corresponding indices. |
|
|
|
|
|
np.unique returns the values in a sorted order, even if the source is not sorted. Thus, if you simply |
|
|
run np.unique on unsorted data, the indices you will get will be invalid. |
|
|
|
|
|
""" |
|
|
    (_, indices, inverse) = np.unique(data, return_index=True, return_inverse=True)
    (_, indices_of_first_occurrence, inverse_indices, counts) = np.unique(
        indices[inverse], return_index=True, return_inverse=True, return_counts=True
    )
    unique_ids = data[indices_of_first_occurrence]
    return unique_ids, indices_of_first_occurrence, inverse_indices, counts
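

# Illustrative usage sketch (hypothetical helper, not called anywhere in this module): shows that
# `np_unique` preserves first-occurrence order, unlike plain np.unique, which sorts the values.
def _example_np_unique() -> None:
    data = np.array(["droid", "bridge", "droid", "roboset"])
    unique_ids, first_idx, inverse, counts = np_unique(data)
    assert unique_ids.tolist() == ["droid", "bridge", "roboset"]  # first-occurrence order
    assert (data[first_idx] == unique_ids).all()  # indices point at the first occurrences
    assert (unique_ids[inverse] == data).all()  # inverse indices reconstruct the input
    assert counts.tolist() == [2, 1, 1]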
|
|
|
|
|
|
|
|
def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
angles: Euler angles in radians in the format 'xyz', shape [..., 3] |
|
|
Returns: |
|
|
torch.Tensor of shape [..., 3, 3] containing rotation matrices |
|
|
""" |
|
|
return roma.euler_to_rotmat(convention="xyz", angles=angles, degrees=False) |
|
|
|
|
|
|
|
|
def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
angles: Euler angles in radians in the format 'xyz', shape [..., 3] |
|
|
Returns: |
|
|
torch.Tensor of shape [..., 4] containing unit quaternions |
|
|
""" |
|
|
return roma.euler_to_unitquat(convention="xyz", angles=angles, degrees=False, normalize=True) |
|
|
|
|
|
|
|
|
def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
quaternion: Unnormalized quaternion, torch.Tensor of shape [..., 4] |
|
|
eps: Small constant to prevent division by zero |
|
|
Returns: |
|
|
torch.Tensor of shape [..., 4] of unit quaternions |
|
|
""" |
|
|
return quaternion / (quaternion.norm(dim=-1, keepdim=True).detach() + eps) |
|
|
|
|
|
|
|
|
def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized |
|
|
Returns: |
|
|
torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3) |
|
|
""" |
|
|
unit_quat = normalize_quaternion(quaternion) |
|
|
rotmat = roma.unitquat_to_euler(convention="xyz", quat=unit_quat, as_tuple=False, degrees=False) |
|
|
return rotmat |
|
|
|
|
|
|
|
|
def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized |
|
|
Returns: |
|
|
torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3) |
|
|
""" |
|
|
unit_quat = normalize_quaternion(quaternion) |
|
|
rotmat = roma.unitquat_to_rotmat(unit_quat) |
|
|
return rotmat |
|
|
|
|
|
|
|
|
def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
rotmat: Batch of rotation matrices, shape [..., 3, 3] |
|
|
Returns: |
|
|
Batch of unit quaternions, shape [..., 4] |
|
|
""" |
|
|
rotmat = rotmat_as_3x3(rotmat) |
|
|
return roma.rotmat_to_unitquat(rotmat) |
|
|
|
|
|
|
|
|
def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
rotmat: Batch of rotation matrices, shape [..., 3, 3] |
|
|
Returns: |
|
|
        Batch of Euler angles in radians ('xyz' convention), shape [..., 3]
|
|
""" |
|
|
rotmat = rotmat_as_3x3(rotmat) |
|
|
return roma.rotmat_to_euler(convention="xyz", rotmat=rotmat, as_tuple=False, degrees=False) |
|
|
|
|
|
|
|
|
def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Maps 9D input vectors onto SO(3) via symmetric orthogonalization. |
|
|
- Let SVD(M) = U \Sigma V^T |
|
|
- Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T |
|
|
- det(UV^T) ensures that det(SVD+(M)) = 1 |
|
|
    - The return value is a rotation matrix (orthonormal) with the least-squares distance to M
|
|
|
|
|
Args: |
|
|
x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3] |
|
|
Returns: |
|
|
torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3) |
|
|
""" |
|
|
with warnings.catch_warnings(): |
|
|
warnings.filterwarnings( |
|
|
"ignore", |
|
|
message="In CPU autocast, but the target dtype is not supported. Disabling autocast.", |
|
|
) |
|
|
with torch.autocast(device_type=x.device.type, dtype=torch.float32): |
|
|
matrices = x.view(-1, 3, 3) |
|
|
matrices = matrices.to(dtype=torch.float32) |
|
|
(u, s, v) = torch.svd(matrices) |
|
|
vt = torch.transpose(v, 1, 2) |
|
|
det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1) |
|
|
diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1) |
|
|
result = torch.matmul(u, diag_vt) |
|
|
result = result.view(*x.shape) |
|
|
result = result.to(dtype=x.dtype) |
|
|
return result |
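

# Minimal sanity sketch (hypothetical helper, not used by the module): perturb identity matrices
# with noise and verify that the projected result is orthonormal, i.e. R @ R^T ~= I.
def _example_symmetric_orthogonalization() -> None:
    noisy = torch.eye(3).expand(2, 3, 3) + 0.05 * torch.randn(2, 3, 3)
    projected = symmetric_orthogonalization(noisy)
    identity = torch.eye(3).expand(2, 3, 3)
    assert torch.allclose(projected @ projected.transpose(-1, -2), identity, atol=1e-4)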
|
|
|
|
|
|
|
|
def is_rotmat_orthonormal( |
|
|
rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = "none" |
|
|
) -> torch.Tensor | bool: |
|
|
""" |
|
|
Check if a rotation matrix is orthonormal or not. |
|
|
Args: |
|
|
rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9] |
|
|
        epsilon: Tolerance for numerical comparisons. Larger values are more permissive; values
            smaller than 1e-6 may incorrectly flag some orthonormal matrices as non-orthonormal
        reduction:
            'none' - returns torch.Tensor of bools with the same batch shape
            'all' - returns a bool, True if ALL matrices in the batch are orthonormal
|
|
Returns: |
|
|
torch.Tensor with the same batch shape or bool |
|
|
""" |
|
|
assert is_rotmat(rotmat) |
|
|
rotmat = rotmat_as_3x3(rotmat.to(dtype=torch.float32)) |
|
|
is_orthonormal = roma.is_orthonormal_matrix(rotmat, epsilon=epsilon) |
|
|
if reduction == "none": |
|
|
return is_orthonormal |
|
|
if reduction == "all": |
|
|
return bool(torch.all(is_orthonormal).item()) |
|
|
raise ValueError(f"Unknown reduction mode {reduction}") |
|
|
|
|
|
|
|
|
def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool: |
|
|
""" |
|
|
    Checks if the tensor shape matches that of a rotation matrix. If the last two dimensions are
    3x3, also checks that the data is a valid (orthonormal) rotation matrix. This avoids a possible
    clash with Euler angles when a batch of Euler angles happens to have shape [..., 3, 3].
    """
    return is_rotmat_9(rotmat) or (
        is_rotmat_3x3(rotmat) and is_rotmat_orthonormal(rotmat, epsilon=0.01, reduction="all")
    )
|
|
|
|
|
|
|
|
def is_euler(euler: torch.Tensor) -> bool: |
|
|
return euler.shape[-1] == 3 and not is_orthonormal_rotmat(euler) |
|
|
|
|
|
|
|
|
def normalize_rotation(rotation: torch.Tensor) -> torch.Tensor: |
|
|
if is_quaternion(rotation): |
|
|
return normalize_quaternion(rotation) |
|
|
if is_euler(rotation): |
|
|
return rotation |
|
|
if is_rotmat(rotation): |
|
|
is_flat = is_rotmat_9(rotation) |
|
|
rotation = rotmat_as_3x3(rotation) if is_flat else rotation |
|
|
rotmat = roma.special_gramschmidt(rotation) |
|
|
rotmat = rotmat_as_9(rotmat) if is_flat else rotmat |
|
|
return rotmat |
|
|
raise ValueError(f"Unknown rotation format: {rotation.shape}") |
|
|
|
|
|
|
|
|
def rotation_format_from_tensor(rotation) -> RotationFormat: |
|
|
if is_quaternion(rotation): |
|
|
return RotationFormat.QUATERNION |
|
|
if is_orthonormal_rotmat(rotation): |
|
|
return RotationFormat.ROTMAT |
|
|
if is_euler(rotation): |
|
|
return RotationFormat.EULER |
|
|
raise ValueError(f"Tensor shape {rotation.shape} is not a valid rotation format") |
|
|
|
|
|
|
|
|
def is_unit_quaternion( |
|
|
quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = "none" |
|
|
) -> torch.Tensor | bool: |
|
|
""" |
|
|
    Check if a quaternion is normalized or not.
    Args:
        quaternion: torch.Tensor of shape [..., 4]
        epsilon: Tolerance for numerical comparisons
|
|
reduction: |
|
|
'none' - returns torch.Tensor of bools with the same batch shape |
|
|
'all' - returns a bool, True if ALL quaternions in the batch are normalized |
|
|
Returns: |
|
|
torch.Tensor with the same batch shape or bool |
|
|
""" |
|
|
assert is_quaternion(quaternion) |
|
|
is_norm = torch.isclose( |
|
|
quaternion.norm(dim=-1, keepdim=True), |
|
|
torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device), |
|
|
atol=epsilon, |
|
|
) |
|
|
if reduction == "none": |
|
|
return is_norm |
|
|
if reduction == "all": |
|
|
return bool(torch.all(is_norm).item()) |
|
|
raise ValueError(f"Unknown reduction mode {reduction}") |
|
|
|
|
|
|
|
|
def convert_rotation( |
|
|
rotation: torch.Tensor | np.ndarray, |
|
|
output_format: RotationFormat, |
|
|
autonorm: bool = True, |
|
|
half_cover: bool = True, |
|
|
) -> torch.Tensor | np.ndarray: |
|
|
is_np = isinstance(rotation, np.ndarray) |
|
|
if is_np: |
|
|
rotation = torch.from_numpy(rotation) |
|
|
if is_quaternion(rotation): |
|
|
if autonorm and not is_unit_quaternion(rotation, reduction="all"): |
|
|
rotation = normalize_quaternion(rotation) |
|
|
if output_format == RotationFormat.QUATERNION: |
|
|
output = rotation |
|
|
elif output_format == RotationFormat.ROTMAT: |
|
|
output = rotmat_as_9(quaternion_to_rotmat(rotation)) |
|
|
elif output_format == RotationFormat.EULER: |
|
|
output = quaternion_to_euler(rotation) |
|
|
else: |
|
|
raise NotImplementedError(f"Unsupported rotation format: {output_format}") |
|
|
elif is_orthonormal_rotmat(rotation): |
|
|
if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction="all"): |
|
|
rotation = symmetric_orthogonalization(rotation) |
|
|
if output_format == RotationFormat.QUATERNION: |
|
|
output = rotmat_to_unit_quaternion(rotation) |
|
|
elif output_format == RotationFormat.ROTMAT: |
|
|
output = rotmat_as_9(rotation) |
|
|
elif output_format == RotationFormat.EULER: |
|
|
output = rotmat_to_euler(rotation) |
|
|
else: |
|
|
raise NotImplementedError(f"Unsupported rotation format: {output_format}") |
|
|
elif is_euler(rotation): |
|
|
if output_format == RotationFormat.QUATERNION: |
|
|
output = euler_to_unit_quaternion(rotation) |
|
|
elif output_format == RotationFormat.ROTMAT: |
|
|
output = rotmat_as_9(euler_to_rotmat(rotation)) |
|
|
elif output_format == RotationFormat.EULER: |
|
|
output = rotation |
|
|
else: |
|
|
raise NotImplementedError(f"Unsupported rotation format: {output_format}") |
|
|
else: |
|
|
raise ValueError(f"Unknown rotation encoding with shape {rotation.shape}") |
|
|
if output_format == RotationFormat.QUATERNION and half_cover: |
|
|
output = quaternion_half_cover(output) |
|
|
if is_np: |
|
|
output = output.numpy() |
|
|
return output |
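

# Round-trip sketch (hypothetical helper, not used by the module): Euler angles -> quaternion ->
# Euler angles via `convert_rotation`. Small angles are used so the 'xyz' round trip is unambiguous.
def _example_convert_rotation_roundtrip() -> None:
    euler = torch.tensor([[0.1, -0.2, 0.3]])
    quat = convert_rotation(euler, RotationFormat.QUATERNION)  # shape [1, 4]
    recovered = convert_rotation(quat, RotationFormat.EULER)   # shape [1, 3]
    assert torch.allclose(recovered, euler, atol=1e-4)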
|
|
|
|
|
|
|
|
def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
    Transform a sequence of rotations, each encoded w.r.t. the PREVIOUS rotation frame in the
    sequence, into rotations encoded w.r.t. the 0-th (base) frame preceding the sequence.
|
|
|
|
|
Ex: |
|
|
`rotation_sequence` contains the rotations: R_01, R_12, R_23, R_34, where R0 is the base frame, |
|
|
implicitly encoded in R_01 and R_10 converts from R0 frame to R1 frame |
|
|
Output: R_01, R_02, R_03, R_04 |
|
|
|
|
|
Args: |
|
|
rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing |
|
|
either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions |
|
|
Returns: |
|
|
torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] containing transformed rotations |
|
|
(R_01, R_02, R_03, R_04, ...) |
|
|
|
|
|
    TODO: Can this be computed without a Python for loop?
|
|
""" |
|
|
assert rotation_sequence.ndim >= 3, rotation_sequence.shape |
|
|
rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence) |
|
|
rotation_sequence = convert_rotation(rotation_sequence, RotationFormat.QUATERNION) |
|
|
batch_dims = np.arange(rotation_sequence.ndim - 2) |
|
|
delta_rotations = torch.cat( |
|
|
[rotation_sequence[..., :1, :]] |
|
|
+ [ |
|
|
roma.quat_composition(rotation_sequence[..., :i, :].permute(-2, *batch_dims, -1).unsqueeze(-2)) |
|
|
for i in range(2, rotation_sequence.shape[-2] + 1) |
|
|
], |
|
|
dim=-2, |
|
|
) |
|
|
delta_rotations = convert_rotation(delta_rotations, rotation_format) |
|
|
return delta_rotations |
|
|
|
|
|
|
|
|
def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray: |
|
|
"""Make sure image is of type np.ndarray and HWC format""" |
|
|
if isinstance(image, PIL.Image.Image): |
|
|
image = np.asarray(image) |
|
|
assert isinstance(image, np.ndarray), type(image) |
|
|
assert image.ndim in [2, 3], image.shape |
|
|
if image.ndim == 3: |
|
|
assert image.shape[-1] <= 4, image.shape |
|
|
return image |
|
|
|
|
|
|
|
|
def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]: |
|
|
if isinstance(image, np.ndarray): |
|
|
(height, width) = image.shape[:2] |
|
|
else: |
|
|
(width, height) = image.size |
|
|
return height, width |
|
|
|
|
|
|
|
|
def pad_image( |
|
|
image: PIL.Image.Image | np.ndarray, |
|
|
target_size: dict[str, int], |
|
|
pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0, |
|
|
) -> PIL.Image.Image | np.ndarray: |
|
|
"""Pad image adding a symmetric border around the height/width.""" |
|
|
assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image) |
|
|
(height, width) = hw_from_image(image) |
|
|
(target_width, target_height) = (target_size["width"], target_size["height"]) |
|
|
if width == target_width and height == target_height: |
|
|
return image |
|
|
assert target_width >= width, f"Can't pad image of width {width} to {target_width}" |
|
|
assert target_height >= height, f"Can't pad image of height {height} to {target_height}" |
|
|
(horizontal_pad, vertical_pad) = ( |
|
|
int((target_width - width) / 2), |
|
|
int((target_height - height) / 2), |
|
|
) |
|
|
if isinstance(image, np.ndarray): |
|
|
padding = ((vertical_pad, vertical_pad), (horizontal_pad, horizontal_pad)) + ((0, 0),) * ( |
|
|
image.ndim - 2 |
|
|
) |
|
|
image = np.pad(image, padding, mode="constant", constant_values=pad_value) |
|
|
else: |
|
|
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad) |
|
|
image = torchvision.transforms.v2.functional.pad( |
|
|
image, padding=padding, fill=pad_value, padding_mode="constant" |
|
|
) |
|
|
return image |
|
|
|
|
|
|
|
|
def pad_image_to_ratio( |
|
|
image: PIL.Image.Image | np.ndarray, |
|
|
target_wh_ratio: float, |
|
|
pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0, |
|
|
) -> PIL.Image.Image | np.ndarray: |
|
|
"""Pad image to a target aspect ratio.""" |
|
|
(height, width) = hw_from_image(image) |
|
|
wh_ratio = width / height |
|
|
if target_wh_ratio >= wh_ratio: |
|
|
pad_size = {"width": round(height * target_wh_ratio), "height": height} |
|
|
else: |
|
|
pad_size = {"width": width, "height": round(width / target_wh_ratio)} |
|
|
image = pad_image(image, target_size=pad_size, pad_value=pad_value) |
|
|
return image |
|
|
|
|
|
|
|
|
def crop_image( |
|
|
image: np.ndarray | PIL.Image.Image, |
|
|
start_height: int, |
|
|
start_width: int, |
|
|
target_height: int, |
|
|
target_width: int, |
|
|
) -> np.ndarray | PIL.Image.Image: |
|
|
np_image = assert_np_hwc_or_hw_image(image) |
|
|
(height, width) = hw_from_image(image) |
|
|
assert target_width <= width, f"Can't crop image of width {width} to {target_width}" |
|
|
assert target_height <= height, f"Can't crop image of width {height} to {target_height}" |
|
|
(start_height, start_width) = (round(start_height), round(start_width)) |
|
|
(target_height, target_width) = (round(target_height), round(target_width)) |
|
|
np_image = np_image[ |
|
|
start_height : start_height + target_height, |
|
|
start_width : start_width + target_width, |
|
|
..., |
|
|
] |
|
|
image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image |
|
|
return image |
|
|
|
|
|
|
|
|
def crop_image_center( |
|
|
image: np.ndarray | PIL.Image.Image, target_size: dict[str, int] |
|
|
) -> np.ndarray | PIL.Image.Image: |
|
|
np_image = assert_np_hwc_or_hw_image(image) |
|
|
(height, width) = np_image.shape[:2] |
|
|
(target_height, target_width) = (target_size["height"], target_size["width"]) |
|
|
assert target_width <= width, f"Can't crop image of width {width} to {target_width}" |
|
|
assert target_height <= height, f"Can't crop image of width {height} to {target_height}" |
|
|
top = (height - target_height) // 2 |
|
|
left = (width - target_width) // 2 |
|
|
np_image = crop_image(np_image, top, left, target_height, target_width) |
|
|
image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image |
|
|
return image |
|
|
|
|
|
|
|
|
def crop_image_to_ratio( |
|
|
image: PIL.Image.Image | np.ndarray, target_wh_ratio: float |
|
|
) -> PIL.Image.Image | np.ndarray: |
|
|
"""Pad image to a target aspect ratio.""" |
|
|
(height, width) = hw_from_image(image) |
|
|
wh_ratio = width / height |
|
|
if target_wh_ratio >= wh_ratio: |
|
|
crop_size = {"width": width, "height": round(width / target_wh_ratio)} |
|
|
else: |
|
|
crop_size = {"width": round(height * target_wh_ratio), "height": height} |
|
|
image = crop_image_center(image, target_size=crop_size) |
|
|
return image |
|
|
|
|
|
|
|
|
def crop_and_pad_image_to_ratio( |
|
|
image: PIL.Image.Image | np.ndarray, |
|
|
target_wh_ratio: float, |
|
|
mode: ResizeMode | str, |
|
|
pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0, |
|
|
) -> PIL.Image.Image | np.ndarray: |
|
|
""" |
|
|
    Crop and pad an image to a target aspect ratio, depending on the mode.
    It's expected that the source image and the target aspect ratio may differ.

    Args:
        image: The image to crop and pad.
        target_wh_ratio: The target width/height aspect ratio to crop and pad the image to.
        mode: The mode to use for cropping and padding.
        pad_value: Fill value used for the padded border.
    """
|
|
(height, width) = hw_from_image(image) |
|
|
wh_ratio = width / height |
|
|
if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001): |
|
|
return image |
|
|
if mode == ResizeMode.SMART: |
|
|
aspect_ratio = max(width, height) / min(width, height) |
|
|
target_ratio = max(target_wh_ratio, 1 / target_wh_ratio) |
|
|
if aspect_ratio == 1: |
|
|
if target_ratio >= 4 / 3 - 0.01: |
|
|
crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4 |
|
|
image = crop_image_to_ratio(image, crop_wh_ratio) |
|
|
else: |
|
|
pass |
|
|
elif aspect_ratio <= 4 / 3 + 0.01: |
|
|
            # crop to a square when the image and target orientations (landscape vs portrait) differ
            if (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0):
                image = crop_image_to_ratio(image, 1.0)
        elif (wh_ratio >= 1.0) != (target_wh_ratio >= 1.0):
            image = crop_image_to_ratio(image, 1.0)
|
|
elif target_ratio >= 4 / 3 + 0.01: |
|
|
pass |
|
|
else: |
|
|
crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4 |
|
|
image = crop_image_to_ratio(image, crop_wh_ratio) |
|
|
image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value) |
|
|
elif mode == ResizeMode.PAD: |
|
|
image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value) |
|
|
elif mode == ResizeMode.CROP: |
|
|
image = crop_image_to_ratio(image, target_wh_ratio) |
|
|
else: |
|
|
raise ValueError(f"Mode {mode} not supported") |
|
|
return image |
|
|
|
|
|
|
|
|
def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool: |
|
|
if isinstance(image, PIL.Image.Image): |
|
|
return image.mode in [ |
|
|
"1", |
|
|
"L", |
|
|
"LA", |
|
|
"La", |
|
|
"P", |
|
|
"PA", |
|
|
"F", |
|
|
"I", |
|
|
"I;16", |
|
|
"I;16L", |
|
|
"I;16B", |
|
|
"I;16N", |
|
|
] |
|
|
if isinstance(image, np.ndarray): |
|
|
        return image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1)
|
|
raise ValueError(f"Unsupported image type: {type(image)}") |
|
|
|
|
|
|
|
|
def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool: |
|
|
image = np.asarray(image) |
|
|
return image.dtype in [np.uint8, np.bool_] and np.max(image) == 1 |
|
|
|
|
|
|
|
|
def resize_image( |
|
|
image: PIL.Image.Image | np.ndarray, |
|
|
target_size: dict[str, int], |
|
|
mode: ResizeMode | str, |
|
|
resample: PIL.Image.Resampling | str = "auto", |
|
|
pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0, |
|
|
) -> PIL.Image.Image | np.ndarray: |
|
|
(target_width, target_height) = (target_size["width"], target_size["height"]) |
|
|
(height, width) = hw_from_image(image) |
|
|
if height == target_height and width == target_width: |
|
|
return image |
|
|
if resample == "auto": |
|
|
if is_single_channel_image(image): |
|
|
resample = PIL.Image.Resampling.BILINEAR |
|
|
else: |
|
|
resample = PIL.Image.Resampling.LANCZOS |
|
|
else: |
|
|
assert isinstance(resample, PIL.Image.Resampling), resample |
|
|
if is_single_channel_image(image) and resample not in [ |
|
|
PIL.Image.Resampling.BILINEAR, |
|
|
PIL.Image.Resampling.BICUBIC, |
|
|
]: |
|
|
raise ValueError( |
|
|
f"Single channel images must be resized with bilinear or bicubic, but got {resample}" |
|
|
) |
|
|
if is_bin_mask := is_binary_mask(image): |
|
|
image = np.asarray(image).astype(np.uint8) * 255 |
|
|
if mode == ResizeMode.SMART: |
|
|
image = crop_and_pad_image_to_ratio( |
|
|
image, |
|
|
target_wh_ratio=target_width / target_height, |
|
|
mode=mode, |
|
|
pad_value=pad_value, |
|
|
) |
|
|
pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image |
|
|
if mode in [ResizeMode.NAIVE, ResizeMode.SMART]: |
|
|
pil_image = pil_image.resize((target_width, target_height), resample=resample) |
|
|
else: |
|
|
raise NotImplementedError(f"Mode {mode} not supported") |
|
|
image = np.asarray(pil_image) if isinstance(image, np.ndarray) else pil_image |
|
|
if is_bin_mask: |
|
|
image = image.astype(np.uint8) > 127 |
|
|
return image |
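

# Usage sketch (hypothetical helper, not used by the module): smart-resize a 4:3 RGB frame to a
# square target; the image is cropped/padded to the target aspect ratio before resampling.
def _example_resize_image_smart() -> None:
    frame = PIL.Image.new("RGB", (640, 480))
    resized = resize_image(frame, target_size={"width": 224, "height": 224}, mode=ResizeMode.SMART)
    assert resized.size == (224, 224)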
|
|
|
|
|
|
|
|
def is_global_norm( |
|
|
norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list], |
|
|
) -> bool: |
|
|
"""Return true if norm is NONE or global for all datasets""" |
|
|
return norm == Normalization.NONE or isinstance(norm, collections.abc.Mapping) |
|
|
|
|
|
|
|
|
def is_mean_norm( |
|
|
norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list], |
|
|
) -> bool: |
|
|
"""Return true if norm is based on mean and std""" |
|
|
    return norm == Normalization.MEAN or (
        isinstance(norm, collections.abc.Mapping) and set(norm.keys()) == {"mean", "std"}
    )
|
|
|
|
|
|
|
|
def _broadcast_shapes( |
|
|
value: torch.Tensor, low: torch.Tensor, high: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Broadcast shapes for normalization: |
|
|
Args: |
|
|
value: torch.Tensor of shape [..., num_components]. The entire shape might be: |
|
|
- [num_components]: `value` has no batch dimension |
|
|
- [num_datasets, num_components]: `value` contains entries *aligned* with the dataset bounds |
|
|
contained in `low` and `high` |
|
|
- [num_datasets, ..., num_components]: `value` contains entries *aligned* with the dataset bounds |
|
|
contained in `low` and `high` |
|
|
- [..., num_components]: `value` contains multiple dimensions. In this case, `low` and `high` |
|
|
must be for a single dataset, i.e. `num_datasets = 1` |
|
|
|
|
|
low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `low` |
|
|
contains normalization bounds for a single dataset |
|
|
high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `high` |
|
|
contains normalization bounds for a single dataset |
|
|
Returns: |
|
|
Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of dimensions as `value` |
|
|
""" |
|
|
assert low.ndim == high.ndim == 2, f"{low.shape} != {high.shape} or ndim != 2" |
|
|
assert value.shape[-1] == low.shape[-1] == high.shape[-1], f"{value.shape} != {low.shape} / {high.shape}" |
|
|
if value.ndim == low.ndim == high.ndim: |
|
|
return low, high |
|
|
if value.ndim < low.ndim: |
|
|
assert low.ndim == high.ndim == 2, f"{low.shape}, {high.shape}" |
|
|
assert low.shape[0] == high.shape[0] == 1, f"{low.shape}, {high.shape}" |
|
|
(low, high) = (low.view(-1), high.view(-1)) |
|
|
return low, high |
|
|
if low.shape[0] == high.shape[0] == 1: |
|
|
low = expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1]) |
|
|
high = expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1]) |
|
|
else: |
|
|
assert value.shape[0] == low.shape[0] == high.shape[0], f"{value.shape} != {low.shape} / {high.shape}" |
|
|
low = expand_dims(low, ndim=value.ndim, order=[1, -1, 1]) |
|
|
high = expand_dims(high, ndim=value.ndim, order=[1, -1, 1]) |
|
|
return low, high |
|
|
|
|
|
|
|
|
def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: |
|
|
(mean, std) = _broadcast_shapes(value, mean, std) |
|
|
(mean, std) = (mean.to(device=value.device), std.to(device=value.device)) |
|
|
return value * (std + 1e-08) + mean |
|
|
|
|
|
|
|
|
def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor: |
|
|
(low, high) = _broadcast_shapes(value, low, high) |
|
|
(low, high) = (low.to(device=value.device), high.to(device=value.device)) |
|
|
return 0.5 * (value + 1) * (high - low) + low |
|
|
|
|
|
|
|
|
def normalize_gripper_by_bounds( |
|
|
value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
If binary, normalize to [0, 1], otherwise normalize to [-1, 1] |
|
|
""" |
|
|
(low, high) = _broadcast_shapes(value, low, high) |
|
|
(low, high) = (low.to(device=value.device), high.to(device=value.device)) |
|
|
if binary: |
|
|
return torch.clamp((value - low) / torch.clamp(high - low, min=1e-08), min=0.0, max=1.0) |
|
|
return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0) |
|
|
|
|
|
|
|
|
def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: |
|
|
(mean, std) = _broadcast_shapes(value, mean, std) |
|
|
(mean, std) = (mean.to(device=value.device), std.to(device=value.device)) |
|
|
return (value - mean) / (std + 1e-08) |
|
|
|
|
|
|
|
|
def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor: |
|
|
(low, high) = _broadcast_shapes(value, low, high) |
|
|
(low, high) = (low.to(device=value.device), high.to(device=value.device)) |
|
|
return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0) |
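

# Round-trip sketch (hypothetical helper, not used by the module): bounds normalization maps values
# into [-1, 1] and `unnormalize_by_bounds` inverts it, given per-dataset bounds of shape
# [num_datasets, num_components].
def _example_bounds_normalization_roundtrip() -> None:
    value = torch.tensor([[0.2, -0.1, 0.4]])
    low = torch.tensor([[-0.5, -0.5, -0.5]])
    high = torch.tensor([[0.5, 0.5, 0.5]])
    normed = normalize_by_bounds(value, low=low, high=high)
    restored = unnormalize_by_bounds(normed, low=low, high=high)
    assert torch.allclose(restored, value, atol=1e-6)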
|
|
|
|
|
|
|
|
def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray: |
|
|
if low < 0.0: |
|
|
return np.clip(-gripper, low, high) |
|
|
return high - np.clip(gripper, low, high) |
|
|
|
|
|
|
|
|
GRIPPER_BOUNDS = { |
|
|
"bridge": (0.0, 1.0), |
|
|
"bridge_orig": (0.0, 1.0), |
|
|
"droid": (0.0, 1.0), |
|
|
"roboset": (0.0, 1.0), |
|
|
} |
|
|
|
|
|
|
|
|
def preprocess_gripper_observation( |
|
|
gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True |
|
|
) -> np.ndarray: |
|
|
""" |
|
|
Preprocess gripper observation depending on dataset. Input is the raw gripper observation from the dataset |
|
|
or from the robot and output is normalized continuous value. |
|
|
- if `binary`, output is in [0, 1], with 0 = closed and 1 = open. |
|
|
- otherwise, output is in [-1, 1], with -1 = closed and 1 = open. |
|
|
|
|
|
Dataset-specific gripper observations: |
|
|
bridge: continuous; ~[0=closed; 1=open] |
|
|
bridge_orig: continuous; ~[0=closed; 1=open] |
|
|
droid: continuous; [0=open, 1=closed] |
|
|
roboset: continuous; [0=open, 1=closed] |
|
|
""" |
|
|
if isinstance(dataset_name, np.ndarray): |
|
|
assert np.unique(dataset_name).size == 1, dataset_name |
|
|
dataset_name = str(dataset_name[0]) |
|
|
if dataset_name in [ |
|
|
"droid", |
|
|
"roboset", |
|
|
]: |
|
|
(low, high) = GRIPPER_BOUNDS[dataset_name] |
|
|
gripper = normalize_gripper_by_bounds( |
|
|
torch.from_numpy(invert_gripper(gripper, low=low, high=high)), |
|
|
low=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][0], dtype=torch.float32), |
|
|
high=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][1], dtype=torch.float32), |
|
|
binary=binary, |
|
|
).numpy() |
|
|
elif dataset_name in [ |
|
|
"bridge", |
|
|
"bridge_orig", |
|
|
]: |
|
|
(low, high) = GRIPPER_BOUNDS[dataset_name] |
|
|
gripper = normalize_gripper_by_bounds( |
|
|
torch.from_numpy(gripper), |
|
|
low=torch.full(gripper.shape, low, dtype=torch.float32), |
|
|
high=torch.full(gripper.shape, high, dtype=torch.float32), |
|
|
binary=binary, |
|
|
).numpy() |
|
|
else: |
|
|
raise NotImplementedError(f"Unknown dataset: {dataset_name}") |
|
|
return gripper |
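

# Usage sketch (hypothetical helper, not used by the module): droid reports 0 = open, 1 = closed,
# so raw values are inverted before normalization; the binary output convention is 1 = open,
# 0 = closed. The 2-D [time, 1] shape matches what the normalization helpers expect.
def _example_preprocess_gripper_observation() -> None:
    raw_droid = np.array([[0.0], [1.0]], dtype=np.float32)
    normed = preprocess_gripper_observation(raw_droid, dataset_name="droid", binary=True)
    assert np.allclose(normed, [[1.0], [0.0]])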
|
|
|
|
|
|
|
|
def rotation_norm_bounds( |
|
|
rotation_norm: Normalization, |
|
|
rotation_format: RotationFormat, |
|
|
stats: Dict[str, Dict[str, Dict[str, List[float]]]], |
|
|
dataset_names: List[str], |
|
|
) -> Dict[str, Dict[str, torch.Tensor]]: |
|
|
if rotation_format == RotationFormat.EULER and rotation_norm != Normalization.NONE: |
|
|
if rotation_norm == Normalization.BOUNDS: |
|
|
results = { |
|
|
dataset_name: { |
|
|
"low": torch.tensor(dataset_stats["euler"]["min"]), |
|
|
"high": torch.tensor(dataset_stats["euler"]["max"]), |
|
|
} |
|
|
for (dataset_name, dataset_stats) in stats.items() |
|
|
} |
|
|
elif rotation_norm == Normalization.BOUNDS_Q99: |
|
|
results = { |
|
|
dataset_name: { |
|
|
"low": torch.tensor(dataset_stats["euler"]["q01"]), |
|
|
"high": torch.tensor(dataset_stats["euler"]["q99"]), |
|
|
} |
|
|
for (dataset_name, dataset_stats) in stats.items() |
|
|
} |
|
|
else: |
|
|
raise NotImplementedError(f"Normalization type {rotation_norm} not yet implemented") |
|
|
else: |
|
|
assert rotation_norm == Normalization.NONE, rotation_norm |
|
|
if rotation_format == RotationFormat.EULER: |
|
|
rotation_size = 3 |
|
|
elif rotation_format == RotationFormat.QUATERNION: |
|
|
rotation_size = 4 |
|
|
else: |
|
|
rotation_size = 9 |
|
|
results = { |
|
|
dataset_name: { |
|
|
"low": -1 * torch.ones(rotation_size, dtype=torch.float32), |
|
|
"high": 1 * torch.ones(rotation_size, dtype=torch.float32), |
|
|
} |
|
|
for dataset_name in dataset_names |
|
|
} |
|
|
return results |
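

# Usage sketch (hypothetical helper, not used by the module): without normalization the bounds
# default to [-1, 1] per rotation component, with one entry per dataset name.
def _example_rotation_norm_bounds() -> None:
    bounds = rotation_norm_bounds(
        rotation_norm=Normalization.NONE,
        rotation_format=RotationFormat.QUATERNION,
        stats={},
        dataset_names=["droid"],
    )
    assert bounds["droid"]["low"].shape == (4,)  # 4 components for quaternions
    assert torch.all(bounds["droid"]["high"] == 1.0)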
|
|
|
|
|
|
|
|
def translation_norm_bounds( |
|
|
translation_norm: Normalization | tuple, |
|
|
stats: Dict[str, Dict[str, Dict[str, List[float]]]], |
|
|
dataset_names: List[str], |
|
|
) -> Dict[str, Dict[str, torch.Tensor]]: |
|
|
if isinstance(translation_norm, (Normalization, str)) and translation_norm != Normalization.NONE: |
|
|
if translation_norm == Normalization.BOUNDS: |
|
|
results = { |
|
|
dataset_name: { |
|
|
"low": torch.tensor(dataset_stats["translation"]["min"]), |
|
|
"high": torch.tensor(dataset_stats["translation"]["max"]), |
|
|
} |
|
|
for (dataset_name, dataset_stats) in stats.items() |
|
|
} |
|
|
elif translation_norm == Normalization.BOUNDS_Q99: |
|
|
results = { |
|
|
dataset_name: { |
|
|
"low": torch.tensor(dataset_stats["translation"]["q01"]), |
|
|
"high": torch.tensor(dataset_stats["translation"]["q99"]), |
|
|
} |
|
|
for (dataset_name, dataset_stats) in stats.items() |
|
|
} |
|
|
elif translation_norm == Normalization.MEAN: |
|
|
results = { |
|
|
dataset_name: { |
|
|
"mean": torch.tensor(dataset_stats["translation"]["mean"]), |
|
|
"std": torch.tensor(dataset_stats["translation"]["std"]), |
|
|
} |
|
|
for (dataset_name, dataset_stats) in stats.items() |
|
|
} |
|
|
else: |
|
|
raise NotImplementedError(f"Normalization type {translation_norm} not yet implemented") |
|
|
elif isinstance(translation_norm, Normalization) and translation_norm == Normalization.NONE: |
|
|
results = { |
|
|
dataset_name: { |
|
|
"low": -1 * torch.ones(3, dtype=torch.float32), |
|
|
"high": 1 * torch.ones(3, dtype=torch.float32), |
|
|
} |
|
|
for dataset_name in dataset_names |
|
|
} |
|
|
else: |
|
|
assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm) |
|
|
assert all((len(value) == 3 for value in translation_norm.values())), translation_norm |
|
|
assert set(translation_norm.keys()) in ( |
|
|
{"low", "high"}, |
|
|
{"mean", "std"}, |
|
|
), translation_norm |
|
|
results = { |
|
|
dataset_name: { |
|
|
key: torch.tensor(value, dtype=torch.float32) for (key, value) in translation_norm.items() |
|
|
} |
|
|
for dataset_name in dataset_names |
|
|
} |
|
|
return results |
|
|
|
|
|
|
|
|
VLAMProcessorConfigT = TypeVar("VLAMProcessorConfigT") |
|
|
|
|
|
|
|
|
class VLAMProcessor(Configurable): |
|
|
def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor): |
|
|
super().__init__(config) |
|
|
self.vlm_processor = vlm_processor |
|
|
self.control_tokenizer = EmptyTokenizer( |
|
|
config=self.config.control_tokenizer_config, tokenizer=self.tokenizer |
|
|
) |
|
|
self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = { |
|
|
"obs_translation": self.obs_translation_norm_bounds, |
|
|
"obs_rotation": self.obs_rotation_norm_bounds, |
|
|
"translation": self.translation_norm_bounds, |
|
|
"rotation": self.rotation_norm_bounds, |
|
|
"joints": self.joints_norm_bounds, |
|
|
} |
|
|
|
|
|
@property |
|
|
def tokenizer(self) -> transformers.PreTrainedTokenizerBase: |
|
|
return self.vlm_processor.tokenizer |
|
|
|
|
|
@property |
|
|
def image_sizes(self) -> Dict[str, ImageSizeConfig]: |
|
|
return self.vlm_processor.image_sizes |
|
|
|
|
|
@property |
|
|
def camera_names(self) -> List[str]: |
|
|
return list(self.vlm_processor.image_sizes.keys()) |
|
|
|
|
|
@property |
|
|
def control_io_config(self) -> ControlDataIOConfig: |
|
|
return self.config.control_io_config |
|
|
|
|
|
@cached_property |
|
|
def rotation_components(self) -> int: |
|
|
if self.config.rotation_format == RotationFormat.EULER: |
|
|
return 3 |
|
|
if self.config.rotation_format == RotationFormat.QUATERNION: |
|
|
return 4 |
|
|
if self.config.rotation_format == RotationFormat.ROTMAT: |
|
|
return 9 |
|
|
raise NotImplementedError(self.config.rotation_format) |
|
|
|
|
|
@abstractmethod |
|
|
def policy_control_plan_from_model_target( |
|
|
self, target: RoboticsTarget, dataset_name: np.ndarray |
|
|
) -> RoboticsControlPlan: |
|
|
pass |
|
|
|
|
|
@abstractmethod |
|
|
def policy_control_plan_from_model_output( |
|
|
self, |
|
|
model_output: RoboticsOutput, |
|
|
dataset_name: np.ndarray, |
|
|
valid_mask: torch.Tensor, |
|
|
) -> RoboticsControlPlan: |
|
|
pass |
|
|
|
|
|
def resize_image( |
|
|
self, camera_name: str, image: PIL.Image.Image | np.ndarray |
|
|
) -> PIL.Image.Image | np.ndarray: |
|
|
return resize_image( |
|
|
image, |
|
|
target_size={ |
|
|
"width": self.image_sizes[camera_name].width, |
|
|
"height": self.image_sizes[camera_name].height, |
|
|
}, |
|
|
mode=self.config.image_resize, |
|
|
resample=PIL.Image.Resampling.LANCZOS, |
|
|
) |
|
|
|
|
|
def preprocess_inputs( |
|
|
self, |
|
|
chat: List[str], |
|
|
images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]], |
|
|
ee_pose_translation: np.ndarray, |
|
|
ee_pose_rotation: np.ndarray, |
|
|
gripper: np.ndarray, |
|
|
joints: np.ndarray, |
|
|
dataset_name: np.ndarray, |
|
|
inference_mode: bool, |
|
|
control_target: Optional[RoboticsTarget] = None, |
|
|
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: |
|
|
""" |
|
|
Preprocess the inputs for a single example |
|
|
        Args:
            chat: List of chat turns containing the language instruction
            images: History of input images per camera, with increasing timestamps
            ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3]
            ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9]
            gripper: np.ndarray with the raw gripper observation(s) from the dataset or the robot
            joints: np.ndarray, shape [..., num_past_scalars, <= 7]
            dataset_name: 1D np.ndarray
            inference_mode: If True, prepare the input for inference (e.g. don't include any
                target tokens in the input, if relevant). If control_target is available, it
                should still be preprocessed for test-dataset comparison
|
|
control_target: RoboticsTarget, each component of shape |
|
|
[..., num_control_steps, num_control_components]. Provided only when available, usually |
|
|
during training and dataset test |
|
|
Returns: |
|
|
Dict containing torch.Tensor with inputs |
|
|
""" |
|
|
del control_target |
|
|
del inference_mode |
|
|
inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images) |
|
|
images: Dict[str, torch.Tensor] = inputs["images"] |
|
|
input_ids: torch.Tensor = inputs["input_ids"][..., : self.tokenizer.model_max_length] |
|
|
target_text_tokens_ids: torch.Tensor = inputs["target_ids"][..., : self.tokenizer.model_max_length] |
|
|
attn_mask = torch.ones(input_ids.shape, dtype=torch.bool) |
|
|
ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32) |
|
|
ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32) |
|
|
ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True) |
|
|
gripper = preprocess_gripper_observation(gripper, dataset_name) |
|
|
gripper = torch.tensor(gripper, dtype=torch.float32) |
|
|
ee_pose_translation = self.normalize( |
|
|
ee_pose_translation, dataset_name=dataset_name, key="obs_translation" |
|
|
) |
|
|
ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key="obs_rotation") |
|
|
joints = torch.tensor(joints, dtype=torch.float32) |
|
|
if joints.shape[-1] < 7: |
|
|
missing_size = 7 - joints.shape[-1] |
|
|
joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1) |
|
|
joints = self.normalize(joints, dataset_name=dataset_name, key="joints") |
|
|
outputs = { |
|
|
"images": images, |
|
|
"input_ids": input_ids, |
|
|
"target_text_tokens_ids": target_text_tokens_ids, |
|
|
"attn_mask": attn_mask, |
|
|
"ee_pose_translation": ee_pose_translation, |
|
|
"ee_pose_rotation": ee_pose_rotation, |
|
|
"gripper": gripper, |
|
|
"joints": joints, |
|
|
"control_tokens_ids": None, |
|
|
"target_control_tokens_ids": None, |
|
|
} |
|
|
return outputs |
|
|
|
|
|
def create_input( |
|
|
self, |
|
|
chat: List[str], |
|
|
images: Dict[str, List[PIL.Image.Image]], |
|
|
ee_pose_translation: np.ndarray, |
|
|
ee_pose_rotation: np.ndarray, |
|
|
gripper: np.ndarray, |
|
|
joints: np.ndarray, |
|
|
dataset_name: np.ndarray, |
|
|
inference_mode: bool, |
|
|
control_target: Optional[RoboticsTarget] = None, |
|
|
) -> RoboticsInput: |
|
|
inputs = self.preprocess_inputs( |
|
|
chat=chat, |
|
|
images=images, |
|
|
ee_pose_translation=ee_pose_translation, |
|
|
ee_pose_rotation=ee_pose_rotation, |
|
|
gripper=gripper, |
|
|
joints=joints, |
|
|
dataset_name=dataset_name, |
|
|
inference_mode=inference_mode, |
|
|
control_target=control_target, |
|
|
) |
|
|
inputs.pop("target_text_tokens_ids") |
|
|
inputs.pop("target_control_tokens_ids") |
|
|
return RoboticsInput(**inputs) |
|
|
|
|
|
def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: |
|
|
if is_mean_norm(getattr(self.config, f"{key}_norm")): |
|
|
(mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) |
|
|
output = normalize_by_moments(value, mean=mean, std=std) |
|
|
else: |
|
|
(low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) |
|
|
output = normalize_by_bounds(value, low=low, high=high) |
|
|
return output |
|
|
|
|
|
def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: |
|
|
if is_mean_norm(getattr(self.config, f"{key}_norm")): |
|
|
(mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) |
|
|
output = unnormalize_by_moments(value, mean=mean, std=std) |
|
|
else: |
|
|
(low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) |
|
|
output = unnormalize_by_bounds(value, low=low, high=high) |
|
|
return output |
|
|
|
|
|
def _norm_bounds_from_dataset_name( |
|
|
self, dataset_name: np.ndarray, component_key: str |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
""" |
|
|
Create an array of normalization bounds corresponding to dataset names |
|
|
Args: |
|
|
dataset_name: Array of shape [B] of dataset names for which to fetch the low and high |
|
|
normalization bounds. Note the values can be repeating |
|
|
            component_key: str. One of 'obs_translation', 'obs_rotation', 'translation', 'rotation'
                or 'joints'. Indicates for which component to fetch the normalization bounds
        Returns:
            Tuple of (low, high) bounds or (mean, std), each of shape [B, -1]
|
|
""" |
|
|
norm = getattr(self.config, f"{component_key}_norm") |
|
|
if is_mean_norm(norm): |
|
|
(stats_key_1, stats_key_2) = ("mean", "std") |
|
|
else: |
|
|
(stats_key_1, stats_key_2) = ("low", "high") |
|
|
if component_key == "joints": |
|
|
if not isinstance(norm, collections.abc.Mapping): |
|
|
raise NotImplementedError() |
|
|
stats = { |
|
|
key: torch.from_numpy(np.tile(np.reshape(value, [1, -1]), [len(dataset_name), 1])) |
|
|
for (key, value) in self.joints_norm_bounds["ANY"].items() |
|
|
} |
|
|
return tuple(stats.values()) |
|
|
component_size = list(list(self.norm_bounds[component_key].values())[0].values())[0].shape[-1] |
|
|
if self.dataset_names == ["ANY"]: |
|
|
stats_1 = self.norm_bounds[component_key]["ANY"][stats_key_1] |
|
|
stats_2 = self.norm_bounds[component_key]["ANY"][stats_key_2] |
|
|
stats_1 = np.repeat(np.expand_dims(stats_1, axis=0), len(dataset_name), axis=0) |
|
|
stats_2 = np.repeat(np.expand_dims(stats_2, axis=0), len(dataset_name), axis=0) |
|
|
else: |
|
|
(unique_names, _, inverse_indices, _) = np_unique(dataset_name) |
|
|
stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32) |
|
|
stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32) |
|
|
for i, ds_name in enumerate(unique_names): |
|
|
stats_1[i] = self.norm_bounds[component_key][ds_name][stats_key_1].numpy() |
|
|
stats_2[i] = self.norm_bounds[component_key][ds_name][stats_key_2].numpy() |
|
|
stats_1 = stats_1[inverse_indices] |
|
|
stats_2 = stats_2[inverse_indices] |
|
|
return torch.from_numpy(stats_1), torch.from_numpy(stats_2) |
|
|
|
|
|
@cached_property |
|
|
def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: |
|
|
return rotation_norm_bounds( |
|
|
rotation_norm=self.config.obs_rotation_norm, |
|
|
rotation_format=self.config.rotation_format, |
|
|
stats=self._observation_stats, |
|
|
dataset_names=self.dataset_names, |
|
|
) |
|
|
|
|
|
@cached_property |
|
|
def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: |
|
|
return translation_norm_bounds( |
|
|
translation_norm=self.config.obs_translation_norm, |
|
|
stats=self._observation_stats, |
|
|
dataset_names=self.dataset_names, |
|
|
) |
|
|
|
|
|
@cached_property |
|
|
def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: |
|
|
return rotation_norm_bounds( |
|
|
rotation_norm=self.config.rotation_norm, |
|
|
rotation_format=self.config.rotation_format, |
|
|
stats=self._control_stats, |
|
|
dataset_names=self.dataset_names, |
|
|
) |
|
|
|
|
|
@cached_property |
|
|
def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: |
|
|
return translation_norm_bounds( |
|
|
translation_norm=self.config.translation_norm, |
|
|
stats=self._control_stats, |
|
|
dataset_names=self.dataset_names, |
|
|
) |
|
|
|
|
|
@cached_property |
|
|
def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: |
|
|
""" |
|
|
NOTE: |
|
|
- Joint values across all joints and all datasets vary in the range [-2pi; 2pi] |
|
|
- The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi] |
|
|
- It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint |
|
|
""" |
|
|
low = torch.tensor(self.config.joints_norm["low"], dtype=torch.float32) |
|
|
high = torch.tensor(self.config.joints_norm["high"], dtype=torch.float32) |
|
|
results = {"ANY": {"low": low, "high": high}} |
|
|
return results |
|
|
|
|
|
@cached_property |
|
|
def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]: |
|
|
return { |
|
|
"bridge": { |
|
|
"euler": { |
|
|
"max": [3.141592653589793, 1.570796251296997, 3.141204357147217], |
|
|
"mean": [ |
|
|
-0.25754162314671525, |
|
|
-0.12370228389510128, |
|
|
0.1620053749182691, |
|
|
], |
|
|
"min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938], |
|
|
"q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896], |
|
|
"q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236], |
|
|
"std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], |
|
|
}, |
|
|
"gripper": { |
|
|
"max": [1.0370277166366577], |
|
|
"min": [0.04637829214334488], |
|
|
"q01": [0.05192930996417999], |
|
|
"q99": [1.0118417739868164], |
|
|
}, |
|
|
"joints": { |
|
|
"max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
}, |
|
|
"translation": { |
|
|
"max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], |
|
|
"mean": [ |
|
|
0.309032678604126, |
|
|
0.03403777256608009, |
|
|
0.061277542263269424, |
|
|
], |
|
|
"min": [ |
|
|
-0.04167502000927925, |
|
|
-0.2889411449432373, |
|
|
-0.13934996724128723, |
|
|
], |
|
|
"q01": [ |
|
|
0.1711955964565277, |
|
|
-0.15639324486255646, |
|
|
-0.048255354166030884, |
|
|
], |
|
|
"q99": [ |
|
|
0.4604376256465912, |
|
|
0.24112474918365479, |
|
|
0.18886254727840424, |
|
|
], |
|
|
"std": [ |
|
|
0.0635896623134613, |
|
|
0.09153717756271362, |
|
|
0.049334850162267685, |
|
|
], |
|
|
}, |
|
|
}, |
|
|
"bridge_orig": { |
|
|
"euler": { |
|
|
"max": [3.141592653589793, 1.570796251296997, 3.141204357147217], |
|
|
"mean": [ |
|
|
-0.25754162314671525, |
|
|
-0.12370228389510128, |
|
|
0.1620053749182691, |
|
|
], |
|
|
"min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938], |
|
|
"q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896], |
|
|
"q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236], |
|
|
"std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], |
|
|
}, |
|
|
"gripper": { |
|
|
"max": [1.0370277166366577], |
|
|
"min": [0.04637829214334488], |
|
|
"q01": [0.05192930996417999], |
|
|
"q99": [1.0118417739868164], |
|
|
}, |
|
|
"joints": { |
|
|
"max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
"std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], |
|
|
}, |
|
|
"translation": { |
|
|
"max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], |
|
|
"mean": [ |
|
|
0.309032678604126, |
|
|
0.03403777256608009, |
|
|
0.061277542263269424, |
|
|
], |
|
|
"min": [ |
|
|
-0.04167502000927925, |
|
|
-0.2889411449432373, |
|
|
-0.13934996724128723, |
|
|
], |
|
|
"q01": [ |
|
|
0.1711955964565277, |
|
|
-0.15639324486255646, |
|
|
-0.048255354166030884, |
|
|
], |
|
|
"q99": [ |
|
|
0.4604376256465912, |
|
|
0.24112474918365479, |
|
|
0.18886254727840424, |
|
|
], |
|
|
"std": [ |
|
|
0.0635896623134613, |
|
|
0.09153717756271362, |
|
|
0.049334850162267685, |
|
|
], |
|
|
}, |
|
|
}, |
|
|
"droid": { |
|
|
"euler": { |
|
|
"max": [3.141592502593994, 1.5705928802490234, 3.1415867805480957], |
|
|
"mean": [ |
|
|
0.3140628098409554, |
|
|
-0.09296274023036387, |
|
|
-0.07227215454779846, |
|
|
], |
|
|
"min": [ |
|
|
-3.141592502593994, |
|
|
-1.5691150426864624, |
|
|
-3.1415374279022217, |
|
|
], |
|
|
"q01": [ |
|
|
-3.1378602981567383, |
|
|
-1.2125312042236327, |
|
|
-2.1614069032669065, |
|
|
], |
|
|
"q99": [3.137854380607605, 0.9200375998020163, 1.9367506909370364], |
|
|
"std": [2.926265757944871, 0.363273475703332, 0.7576065217938824], |
|
|
}, |
|
|
"gripper": { |
|
|
"max": [1.0], |
|
|
"min": [0.0], |
|
|
"q01": [0.0], |
|
|
"q99": [0.9911894202232361], |
|
|
}, |
|
|
"joints": { |
|
|
"max": [ |
|
|
2.668445110321045, |
|
|
1.5691218376159668, |
|
|
2.666306734085083, |
|
|
-0.3114914000034332, |
|
|
2.6624162197113037, |
|
|
4.28157901763916, |
|
|
2.752457857131958, |
|
|
], |
|
|
"mean": [ |
|
|
0.023137084334640106, |
|
|
0.2704989977282293, |
|
|
-0.01451389357228282, |
|
|
-2.018709403792315, |
|
|
-0.042720520800030394, |
|
|
2.350281188152209, |
|
|
0.12424663946659845, |
|
|
], |
|
|
"min": [ |
|
|
-2.6536705493927, |
|
|
-1.547789216041565, |
|
|
-2.6781487464904785, |
|
|
-2.9409868717193604, |
|
|
-2.6705946922302246, |
|
|
0.24893812835216522, |
|
|
-2.7615714073181152, |
|
|
], |
|
|
"q01": [ |
|
|
-0.9026106441020965, |
|
|
-0.8547340619564057, |
|
|
-0.9028875434398651, |
|
|
-2.7698556280136106, |
|
|
-1.6851656341552732, |
|
|
1.2335169839859008, |
|
|
-1.9587260699272155, |
|
|
], |
|
|
"q99": [ |
|
|
0.9569852340221403, |
|
|
1.4148830294609054, |
|
|
0.7693877756595566, |
|
|
-0.4545914208889008, |
|
|
1.5623322343826267, |
|
|
3.475611729621887, |
|
|
2.263479118347167, |
|
|
], |
|
|
"std": [ |
|
|
0.31695080251469465, |
|
|
0.49522214687158767, |
|
|
0.27993538230553827, |
|
|
0.478161574676113, |
|
|
0.4969961591445458, |
|
|
0.45101008525403846, |
|
|
0.7287264344068457, |
|
|
], |
|
|
}, |
|
|
"translation": { |
|
|
"max": [0.8575563430786133, 0.799155592918396, 1.0043904781341553], |
|
|
"mean": [ |
|
|
0.5283099395864883, |
|
|
0.005363794653877434, |
|
|
0.3120132207021294, |
|
|
], |
|
|
"min": [ |
|
|
-0.15604186058044434, |
|
|
-0.827903687953949, |
|
|
-0.2347021996974945, |
|
|
], |
|
|
"q01": [ |
|
|
0.26669957995414734, |
|
|
-0.43774398624897004, |
|
|
-0.048167889714241026, |
|
|
], |
|
|
"q99": [0.7774086785316463, 0.428325751423835, 0.776091011762619], |
|
|
"std": [ |
|
|
0.1148424841779685, |
|
|
0.17489566608140428, |
|
|
0.16541062032731538, |
|
|
], |
|
|
}, |
|
|
}, |
|
|
"roboset": { |
|
|
"euler": { |
|
|
"max": [3.1415449294818236, 1.5705575529715636, 3.141527342124582], |
|
|
"mean": [ |
|
|
-0.0398455755412464, |
|
|
1.0518070390619125, |
|
|
-0.015345692503002759, |
|
|
], |
|
|
"min": [ |
|
|
-3.1415813300509536, |
|
|
-1.5222832468962035, |
|
|
-3.141575300866071, |
|
|
], |
|
|
"q01": [ |
|
|
-2.9414386317311187, |
|
|
-0.24976770655101155, |
|
|
-2.985256521212579, |
|
|
], |
|
|
"q99": [2.9380437893235993, 1.5403010739503078, 2.9746912523985025], |
|
|
"std": [1.7866587696177456, 0.40620530263065, 1.7288511340250616], |
|
|
}, |
|
|
"gripper": { |
|
|
"max": [0.83056640625], |
|
|
"min": [0.0001499652862548828], |
|
|
"q01": [0.0001499652862548828], |
|
|
"q99": [0.82666015625], |
|
|
}, |
|
|
"joints": { |
|
|
"max": [ |
|
|
0.96240234375, |
|
|
1.1162109375, |
|
|
1.1064453125, |
|
|
-0.98095703125, |
|
|
2.30859375, |
|
|
1.576171875, |
|
|
1.7412109375, |
|
|
], |
|
|
"mean": [ |
|
|
0.005913593806326389, |
|
|
0.1877261847257614, |
|
|
0.04653879255056381, |
|
|
-2.0529513359069824, |
|
|
-0.011298442259430885, |
|
|
0.6185526251792908, |
|
|
-0.01701134257018566, |
|
|
], |
|
|
"min": [ |
|
|
-0.8330078125, |
|
|
-0.74658203125, |
|
|
-0.8642578125, |
|
|
-2.892578125, |
|
|
-1.390625, |
|
|
-0.24658203125, |
|
|
-2.953125, |
|
|
], |
|
|
"q01": [ |
|
|
-0.41015625, |
|
|
-0.5302734375, |
|
|
-0.6455078125, |
|
|
-2.57421875, |
|
|
-0.76416015625, |
|
|
-0.0386962890625, |
|
|
-1.435546875, |
|
|
], |
|
|
"q99": [ |
|
|
0.66455078125, |
|
|
0.9501953125, |
|
|
0.7529296875, |
|
|
-1.251953125, |
|
|
0.75244140625, |
|
|
1.2314453125, |
|
|
1.384765625, |
|
|
], |
|
|
"std": [ |
|
|
0.17915399372577667, |
|
|
0.32234326004981995, |
|
|
0.26069700717926025, |
|
|
0.31767210364341736, |
|
|
0.205329030752182, |
|
|
0.33385637402534485, |
|
|
0.6263682842254639, |
|
|
], |
|
|
}, |
|
|
"translation": { |
|
|
"max": [0.5747738480567932, 0.3972920775413513, 0.7443570494651794], |
|
|
"mean": [ |
|
|
0.3331542909145355, |
|
|
0.019357483834028244, |
|
|
0.37330344319343567, |
|
|
], |
|
|
"min": [ |
|
|
0.09978063404560089, |
|
|
-0.29593944549560547, |
|
|
0.10065606236457825, |
|
|
], |
|
|
"q01": [ |
|
|
0.18437016010284424, |
|
|
-0.25699371099472046, |
|
|
0.15134164690971375, |
|
|
], |
|
|
"q99": [0.543661892414093, 0.29646238684654236, 0.6682320833206177], |
|
|
"std": [ |
|
|
0.07849054038524628, |
|
|
0.12241040915250778, |
|
|
0.1460595279932022, |
|
|
], |
|
|
}, |
|
|
}, |
|
|
} |
|
|
|
|
|
@cached_property |
|
|
def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]: |
|
|
if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm): |
|
|
return {} |
|
|
with open(self.config.control_stats_path, "r") as file: |
|
|
stats = yaml.safe_load(file) |
|
|
if self.config.delta_controls: |
|
|
if self.control_io_config.future_controls_sequence_stride_sec is None: |
|
|
horizon = 0.0 |
|
|
else: |
|
|
horizon = self.control_io_config.future_controls_sequence_stride_sec |
|
|
elif self.control_io_config.future_controls_sequence_stride_sec is None: |
|
|
if self.control_io_config.future_controls_sequence_length == 1: |
|
|
horizon = 0.0 |
|
|
else: |
|
|
raise NotImplementedError() |
|
|
else: |
|
|
horizon = ( |
|
|
self.control_io_config.future_controls_sequence_length |
|
|
* self.control_io_config.future_controls_sequence_stride_sec |
|
|
) |
|
|
key = f"horizon_{round(horizon, 2)}s" |
|
|
if key in stats: |
|
|
stats = stats[key] |
|
|
else: |
|
|
raise ValueError( |
|
|
f"Missing control statistics key {key} for future_controls_sequence_length={self.config.control_io_config.future_controls_sequence_length} future_controls_sequence_stride_sec={self.config.control_io_config.future_controls_sequence_stride_sec}. Available keys: [{stats.keys()}]" |
|
|
) |
|
|
return stats |
|
|
|
|
|
@cached_property |
|
|
def dataset_names(self) -> List[str]: |
|
|
if ( |
|
|
is_global_norm(self.config.rotation_norm) |
|
|
and is_global_norm(self.config.obs_rotation_norm) |
|
|
and is_global_norm(self.config.translation_norm) |
|
|
and is_global_norm(self.config.obs_translation_norm) |
|
|
): |
|
|
return ["ANY"] |
|
|
return list(set(self._control_stats.keys()) | set(self._observation_stats.keys())) |
|
|
|
|
|
|
|
|
def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Transform a sequence of translation vectors encoded w.r.t. PREVIOUS frame in the sequence to encoding |
|
|
w.r.t. the 0-th element preceding the sequence |
|
|
Ex: |
|
|
Sequence of points: T1, T2, T3, T4 |
|
|
`translation_sequence` contains the vectors: T0T1, T1T2, T2T3, T3T4, where T0 is the base frame, |
|
|
implicitly encoded in T0T1 |
|
|
Output: T0T1, T0T2, T0T3, T0T4 |
|
|
|
|
|
Args: |
|
|
translation_sequence: torch.Tensor of shape [..., S, 3], containing the translation vectors, where S |
|
|
corresponds to the sequence dimension |
|
|
Returns: |
|
|
torch.Tensor of the same shape as translation_sequence, containing delta translations |
|
|
""" |
|
|
assert translation_sequence.ndim >= 3, translation_sequence.shape |
|
|
    relative_translations = torch.cumsum(translation_sequence, dim=-2)
    return relative_translations
|
|
|
|
|
|
|
|
class RegressionProcessor(VLAMProcessor): |
|
|
def policy_control_plan_from_model_target( |
|
|
self, target: RoboticsTarget, dataset_name: np.ndarray |
|
|
) -> RoboticsControlPlan: |
|
|
translation_m = self.unnormalize(target.translation, dataset_name=dataset_name, key="translation") |
|
|
rotation = self.unnormalize(target.rotation, dataset_name=dataset_name, key="rotation") |
|
|
rotmat = convert_rotation(rotation, RotationFormat.ROTMAT) |
|
|
gripper_prob = target.gripper |
|
|
if self.config.delta_controls: |
|
|
translation_m = delta_to_relative_translations(translation_m) |
|
|
rotmat = delta_to_relative_rotations(rotmat) |
|
|
return RoboticsControlPlan( |
|
|
translation_m=translation_m, |
|
|
rotmat=rotmat, |
|
|
gripper_prob=gripper_prob, |
|
|
valid_mask=target.valid_mask, |
|
|
) |
|
|
|
|
|
def policy_control_plan_from_model_output( |
|
|
self, |
|
|
model_output: RoboticsOutput, |
|
|
dataset_name: np.ndarray, |
|
|
valid_mask: torch.Tensor, |
|
|
) -> RoboticsControlPlan: |
|
|
"""Called during inference to create control plan from model output""" |
|
|
translation_m = self.unnormalize( |
|
|
model_output.translation, dataset_name=dataset_name, key="translation" |
|
|
) |
|
|
rotation = self.unnormalize(model_output.rotation, dataset_name=dataset_name, key="rotation") |
|
|
rotmat = convert_rotation(rotation, RotationFormat.ROTMAT, autonorm=True) |
|
|
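        # The model emits gripper logits, so squash them into a probability here; in
        # policy_control_plan_from_model_target the target gripper value is used as a probability directly.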
gripper_prob = torch.sigmoid(model_output.gripper) |
|
|
if self.config.delta_controls: |
|
|
translation_m = delta_to_relative_translations(translation_m) |
|
|
rotmat = delta_to_relative_rotations(rotmat) |
|
|
return RoboticsControlPlan( |
|
|
translation_m=translation_m, |
|
|
rotmat=rotmat, |
|
|
gripper_prob=gripper_prob, |
|
|
valid_mask=valid_mask, |
|
|
) |
|
|
|
|
|
|
|
|
class PiZeroFlowMatchingProcessor(RegressionProcessor): |
|
|
def __init__(self, **kwargs): |
|
|
super().__init__(**kwargs) |
|
|
self.generator: torch.Generator = torch.Generator() |
|
|
|
|
|
@cached_property |
|
|
def beta_distribution(self) -> torch.distributions.Beta: |
|
|
return torch.distributions.Beta( |
|
|
self.config.distribution_hyperparams.get("alpha", 1.5), |
|
|
self.config.distribution_hyperparams.get("beta", 1.0), |
|
|
) |
|
|
|
|
|
def create_input(self, *args, **kwargs) -> RoboticsFlowInput: |
|
|
"""In practice used only during inference""" |
|
|
inputs = super().create_input(*args, **kwargs) |
|
|
flow_input: FlowInput = self.sample_t0_input(batch_size=1, device=torch.device("cpu")) |
|
|
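        # sample_t0_input returns a batch of size 1; index out the batch dimension since
        # create_input builds a single, unbatched sample.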
inputs = RoboticsFlowInput(**inputs.as_json(), flow_input=flow_input[0, ...]) |
|
|
return inputs |
|
|
|
|
|
def sample_timestep(self, batch_size: int) -> torch.Tensor: |
|
|
if self.config.timestep_distribution.lower() == "uniform": |
|
|
eps = 1e-05 |
|
|
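            # Stratified sampling of t in [0, 1): one shared random offset plus an evenly spaced
            # grid over the batch, wrapped modulo (1 - eps) so samples stay strictly below 1.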
sample = (torch.rand(1, generator=self.generator) + torch.arange(batch_size) / batch_size) % ( |
|
|
1 - eps |
|
|
) |
|
|
elif self.config.timestep_distribution.lower() == "beta": |
|
|
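            # Draw s ~ Beta(alpha, beta) and map it to t = (1 - sig_min) * (1 - s); with the
            # default alpha=1.5, beta=1.0 this concentrates timesteps near 0.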
sample = self.beta_distribution.sample([batch_size, 1, 1]) |
|
|
sample = (1 - self.config.sig_min) * (1 - sample) |
|
|
else: |
|
|
raise NotImplementedError(self.config.timestep_distribution) |
|
|
sample = sample.view(batch_size, 1, 1) |
|
|
return sample |
|
|
|
|
|
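    # Conditional flow-matching path (and its time derivative) used to build training targets:
    #   psi_t(x0, x1) = (1 - (1 - sig_min) * t) * x0 + t * x1
    #   d psi_t / dt  = x1 - (1 - sig_min) * x0
    # At t=0 the path starts at the noise sample x0; at t=1 it reaches x1 up to a sig_min * x0 term.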
def _psi_t(self, timestep: torch.Tensor, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor: |
|
|
return (1 - (1 - self.config.sig_min) * timestep) * x_0 + timestep * x_1 |
|
|
|
|
|
def _dpsi_dt(self, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor: |
|
|
return x_1 - (1 - self.config.sig_min) * x_0 |
|
|
|
|
|
def sample_t0_input(self, batch_size: int, device: torch.device) -> FlowInput: |
|
|
if self.config.r0_distribution == "normal": |
|
|
controls_t0 = torch.randn( |
|
|
[ |
|
|
batch_size, |
|
|
self.config.control_io_config.future_controls_sequence_length, |
|
|
3 + self.rotation_components + 1, |
|
|
], |
|
|
generator=self.generator, |
|
|
).to(device=device) |
|
|
(translation_t0, rotation_t0, gripper_t0) = torch.split( |
|
|
controls_t0, [3, self.rotation_components, 1], dim=-1 |
|
|
) |
|
|
rotation_t0 = normalize_rotation(rotation_t0) |
|
|
elif self.config.r0_distribution == "uniform": |
|
|
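            # "uniform" refers to the rotation, which is drawn uniformly over SO(3) below;
            # translation and gripper noise are still sampled from a standard normal.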
controls_t0 = torch.randn( |
|
|
[ |
|
|
batch_size, |
|
|
self.config.control_io_config.future_controls_sequence_length, |
|
|
4, |
|
|
], |
|
|
generator=self.generator, |
|
|
).to(device=device) |
|
|
(translation_t0, gripper_t0) = torch.split(controls_t0, [3, 1], dim=-1) |
|
|
rotation_t0 = convert_rotation( |
|
|
roma.random_unitquat( |
|
|
( |
|
|
batch_size, |
|
|
self.config.control_io_config.future_controls_sequence_length, |
|
|
), |
|
|
device=device, |
|
|
), |
|
|
self.config.rotation_format, |
|
|
) |
|
|
else: |
|
|
raise NotImplementedError(self.config.r0_distribution) |
|
|
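        # Presumably restricts each quaternion to a single hemisphere so that q and -q
        # (which encode the same rotation) are represented consistently.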
if self.config.rotation_format == RotationFormat.QUATERNION: |
|
|
rotation_t0 = quaternion_half_cover(rotation_t0) |
|
|
timestep = torch.zeros([batch_size, 1, 1], device=device) |
|
|
return FlowInput( |
|
|
timestep=timestep, |
|
|
translation_t0=translation_t0, |
|
|
rotation_t0=rotation_t0, |
|
|
gripper_t0=gripper_t0, |
|
|
translation_t=None, |
|
|
rotation_t=None, |
|
|
gripper_t=None, |
|
|
) |
|
|
|
|
|
def policy_control_plan_from_model_output( |
|
|
self, |
|
|
model_output: RoboticsOutput, |
|
|
dataset_name: np.ndarray, |
|
|
valid_mask: torch.Tensor, |
|
|
) -> RoboticsControlPlan: |
|
|
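        # Clip the flow head's translation/rotation predictions to [-1, 1] in normalized space
        # (only for the normalization modes checked below) before unnormalizing.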
if self.config.translation_norm == Normalization.NONE or is_mean_norm(self.config.translation_norm): |
|
|
model_output = model_output.replace(translation=torch.clamp(model_output.translation, -1, 1)) |
|
|
if self.config.rotation_norm == Normalization.NONE or is_mean_norm(self.config.rotation_norm): |
|
|
model_output = model_output.replace(rotation=torch.clamp(model_output.rotation, -1, 1)) |
|
|
control_plan = super().policy_control_plan_from_model_output(model_output, dataset_name, valid_mask) |
|
|
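        # The flow head regresses the gripper signal directly as a probability, so the sigmoid
        # applied by the parent implementation is overridden with a clamp to [0, 1].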
control_plan = control_plan.replace(gripper_prob=torch.clamp(model_output.gripper, 0, 1)) |
|
|
return control_plan |
|
|
|
|
|
|
|
|
def make_causal_mask(shape: Sequence[int]) -> torch.Tensor: |
|
|
""" |
|
|
Create a causal attention mask of shape `shape` |
|
|
Args: |
|
|
shape: Shape of the output mask, the last two dimensions correspond to [query_seq_len, kv_seq_len] |
|
|
Returns: |
|
|
torch.Tensor of dtype torch.bool. False values indicate that the row (i.e. query) can't attend |
|
|
to the corresponding column (i.e. key) |
|
|
|
|
|
Example: |
|
|
shape = (3, 5) -> Mask the upper triangular part |
|
|
[ |
|
|
[ 1, 0, 0, 0, 0], |
|
|
[ 1, 1, 0, 0, 0], |
|
|
[ 1, 1, 1, 0, 0] |
|
|
] |
|
|
""" |
|
|
return torch.tril(torch.ones(shape, dtype=torch.bool), diagonal=0) |
|
|
|
|
|
|
|
|
def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Enable full bi-directional attention in `attn_mask` inside specific blocks |
|
|
Args: |
|
|
attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype torch.bool |
|
|
where False values indicate disabled attention |
|
|
full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. Blocks of True values indicate |
|
|
positions where full bi-directional attention should be enabled |
|
|
|
|
|
Example: |
|
|
1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, |
|
|
1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, |
|
|
1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, |
|
|
1, 1, 1, 1, 0, 0, 0, 0, -> 1, 1, 1, 1, 0, 0, 0, 0, |
|
|
1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, |
|
|
1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, |
|
|
1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, |
|
|
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, |
|
|
|
|
|
""" |
|
|
assert full_attn.dtype == torch.bool, full_attn.dtype |
|
|
assert full_attn.ndim == 1, full_attn.shape |
|
|
assert full_attn.shape[0] == attn_mask.shape[-2], f"{full_attn.shape[0]}, {attn_mask.shape}" |
|
|
if attn_mask.shape[-1] != attn_mask.shape[-2]: |
|
|
raise NotImplementedError("Only self-attention supported right now.") |
|
|
    # Mark (query, key) pairs where both positions fall inside a full-attention region.
    x = full_attn.view(-1, 1) & full_attn.view(1, -1)
    # Union with a causal mask so each row starts as a contiguous prefix of allowed keys.
    x = x | make_causal_mask([full_attn.shape[0], full_attn.shape[0]])
    # Keep only that leading contiguous run: a causal position must not see later full-attention blocks.
    x = torch.cumprod(x, dim=1).to(dtype=torch.bool)
    # Symmetrize; what survives is the diagonal plus a full square over each full-attention block.
    x = x & x.permute(1, 0)
    # Zero the diagonal of purely causal positions so only the block squares are added to the mask.
    mask_positions = (torch.sum(x, dim=0) == 1) & ~full_attn
    mask_indices = torch.where(mask_positions)[0]
    x[mask_indices, mask_indices] = 0
    attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1])
    return attn_mask
|
|
return attn_mask |
|
|
|
|
|
|
|
|
IGNORE_INDEX = -100 |
|
|
|
|
|
|
|
|
class PaliGemmaProcessor(VLMProcessor): |
|
|
def __init__( |
|
|
self, |
|
|
config: PaliGemmaProcessorConfig, |
|
|
hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor, |
|
|
**kwargs, |
|
|
): |
|
|
del kwargs |
|
|
super().__init__(config) |
|
|
self.hf_processor = hf_processor |
|
|
self.hf_processor.image_processor.size = dict(self.config.image_sizes["main"].as_json()) |
|
|
self.hf_processor.image_seq_length = self.config.num_image_tokens["main"] |
|
|
self.hf_processor.image_processor.image_seq_length = self.config.num_image_tokens["main"] |
|
|
self.bos_id: int = self.tokenizer.bos_token_id |
|
|
self.eos_id: int = self.tokenizer.eos_token_id |
|
|
self.sep_token = "\n" |
|
|
self.sep_id: int = self.tokenizer( |
|
|
self.sep_token, |
|
|
padding=False, |
|
|
add_special_tokens=False, |
|
|
return_attention_mask=False, |
|
|
)["input_ids"][0] |
|
|
self.image_token_id: int = self.tokenizer( |
|
|
self.config.image_token, |
|
|
padding=False, |
|
|
add_special_tokens=False, |
|
|
return_attention_mask=False, |
|
|
)["input_ids"][0] |
|
|
        self.image_tokens: List[int] = [self.image_token_id] * sum(self.config.num_image_tokens.values())
|
|
        self.bbox_pattern = re.compile(
            r"\[(\d+\.\d+),\s*(\d+\.\d+),\s*(\d+\.\d+),\s*(\d+\.\d+)\]"
        )
|
|
|
|
|
def preprocess_inputs( |
|
|
self, chat: List[str], images: Dict[str, List[PIL.Image.Image]] |
|
|
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: |
|
|
""" |
|
|
Based on PaliGemma paper https://arxiv.org/pdf/2407.07726 and example code at |
|
|
https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs |
|
|
        Chat must always consist of alternating user and model messages, starting with the user:
|
|
|
|
|
<image><image> ... <bos><instruction><sep><assistant><sep><instruction><sep><assistant>...<eos> |
|
|
|
|
|
Args: |
|
|
chat: List[str] of even size where each entry corresponds to a different turn in the conversation |
|
|
            images: Dict[str, List[PIL.Image.Image]] mapping each camera name to its history of images
|
|
""" |
|
|
for key, value in images.items(): |
|
|
if not isinstance(value, list): |
|
|
raise TypeError(f"Camera {key} contains values of type {type(value)} instead of list") |
|
|
(input_ids, target_ids) = ([], []) |
|
|
for i, text in enumerate(chat): |
|
|
text = text.replace(self.sep_token, " ").replace("<image>", "") |
|
|
text = self.bbox_pattern.sub(self._bbox_to_loc_tokens, text) |
|
|
turn_input_ids: List[int] = self.tokenizer( |
|
|
text, |
|
|
padding=False, |
|
|
add_special_tokens=False, |
|
|
return_attention_mask=False, |
|
|
)["input_ids"] |
|
|
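            # Even-indexed turns are user messages (excluded from the loss via IGNORE_INDEX);
            # odd-indexed turns are model responses and are supervised.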
if i % 2 == 0: |
|
|
turn_target_ids = [IGNORE_INDEX] * len(turn_input_ids) |
|
|
else: |
|
|
turn_target_ids = turn_input_ids |
|
|
if i != len(chat) - 1: |
|
|
turn_input_ids = turn_input_ids + [self.sep_id] |
|
|
turn_target_ids = turn_target_ids + [IGNORE_INDEX] |
|
|
input_ids = input_ids + turn_input_ids |
|
|
target_ids = target_ids + turn_target_ids |
|
|
input_ids = [self.bos_id] + input_ids + [self.eos_id] |
|
|
target_ids = [IGNORE_INDEX] + target_ids + [self.eos_id] |
|
|
image_tokens = self.image_tokens |
|
|
if self.config.max_language_tokens > 0: |
|
|
input_ids = input_ids[: self.config.max_language_tokens] |
|
|
target_ids = target_ids[: self.config.max_language_tokens] |
|
|
input_ids = image_tokens + input_ids |
|
|
target_ids = [IGNORE_INDEX] * len(image_tokens) + target_ids |
|
|
input_ids = torch.tensor(input_ids, dtype=torch.int64) |
|
|
target_ids = torch.tensor(target_ids, dtype=torch.int64) |
|
|
image_tensors: Dict[str, torch.Tensor] = { |
|
|
f"{camera_name}.siglip": self.hf_processor.image_processor( |
|
|
camera_images, |
|
|
size=self.config.image_sizes[camera_name].as_json(), |
|
|
return_tensors="pt", |
|
|
)["pixel_values"] |
|
|
for (camera_name, camera_images) in images.items() |
|
|
} |
|
|
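        # PaliGemma-style prefix-LM attention: non-supervised prefix tokens (images and prompt)
        # attend to each other bi-directionally, while supervised target tokens stay causal.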
attn_mask = make_causal_mask([len(input_ids), len(input_ids)]) |
|
|
attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX) |
|
|
return { |
|
|
"input_ids": input_ids, |
|
|
"target_ids": target_ids, |
|
|
"images": image_tensors, |
|
|
"attn_mask": attn_mask, |
|
|
} |
|
|
|
|
|
@property |
|
|
def tokenizer(self) -> transformers.PreTrainedTokenizerBase: |
|
|
return self.hf_processor.tokenizer |
|
|
|
|
|
@staticmethod |
|
|
    def _bbox_to_loc_tokens(match: re.Match) -> str:
|
|
""" |
|
|
https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/ |
|
|
""" |
|
|
floats = list(map(float, match.groups())) |
|
|
transformed = [f"<loc{np.clip(round(num * 1024), 0, 1023):04d}>" for num in floats] |
|
|
return f"[{', '.join(transformed)}]" |
|
|
|
|
|
@property |
|
|
def image_sizes(self) -> Dict[str, ImageSizeConfig]: |
|
|
return self.config.image_sizes |
|
|
|
|
|
|
|
|
class PaliGemmaDepthProcessor(PaliGemmaProcessor): |
|
|
def __init__( |
|
|
self, |
|
|
config: PaliGemmaProcessorConfig, |
|
|
hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor, |
|
|
depth_tokens: int, |
|
|
): |
|
|
super().__init__(config, hf_processor) |
|
|
vocab_size = len(self.tokenizer) |
|
|
self.depth_token_ids = np.arange(vocab_size - depth_tokens, vocab_size) |
|
|
self.depth_input_transforms = { |
|
|
camera_name: torchvision.transforms.v2.Compose( |
|
|
[ |
|
|
torchvision.transforms.v2.Resize( |
|
|
size=(camera_image_size.height, camera_image_size.width), |
|
|
interpolation=torchvision.transforms.v2.InterpolationMode.BICUBIC, |
|
|
max_size=None, |
|
|
antialias=True, |
|
|
), |
|
|
torchvision.transforms.v2.ToTensor(), |
|
|
torchvision.transforms.v2.Normalize( |
|
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
|
|
), |
|
|
] |
|
|
) |
|
|
for (camera_name, camera_image_size) in self.config.image_sizes.items() |
|
|
} |
|
|
|
|
|
def preprocess_inputs( |
|
|
self, chat: List[str], images: Dict[str, List[PIL.Image.Image]] |
|
|
) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: |
|
|
inputs = super().preprocess_inputs(chat=chat, images=images) |
|
|
depth_images: Dict[str, torch.Tensor] = { |
|
|
f"{camera_name}.depth": torch.stack( |
|
|
self.depth_input_transforms[camera_name](camera_images), dim=0 |
|
|
) |
|
|
for (camera_name, camera_images) in images.items() |
|
|
} |
|
|
inputs["images"] = {**inputs["images"], **depth_images} |
|
|
return inputs |
|
|
|
|
|
@property |
|
|
def num_depth_tokens(self) -> int: |
|
|
return len(self.depth_token_ids) |
|
|
|