|
|
import functools |
|
|
|
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
import openpi.shared.array_typing as at |
|
|
|
|
|
|
|
|
# height, width and method are static: they are plain Python values that determine
# the output shape, so each distinct combination is traced and compiled separately.
@functools.partial(jax.jit, static_argnums=(1, 2, 3))
@at.typecheck
def resize_with_pad(
    images: at.UInt8[at.Array, "*b h w c"] | at.Float[at.Array, "*b h w c"],
    height: int,
    width: int,
    method: jax.image.ResizeMethod = jax.image.ResizeMethod.LINEAR,
) -> at.UInt8[at.Array, "*b {height} {width} c"] | at.Float[at.Array, "*b {height} {width} c"]:
    """Replicates tf.image.resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].

    The image is scaled by a single factor so that it fits inside ``height`` x ``width`` while keeping
    its aspect ratio, then centred; the remaining border is filled with 0 (uint8) or -1.0 (float32).
    When the leftover space is odd, the extra row/column goes to the bottom/right.

    Args:
        images: Channels-last image(s) of shape ``[*b, h, w, c]`` with any number of leading batch
            dimensions (including none). Must be uint8 or float32.
        height: Target height.
        width: Target width.
        method: Interpolation method forwarded to ``jax.image.resize``.

    Returns:
        Array of shape ``[*b, height, width, c]`` with the same dtype as ``images``.

    Raises:
        ValueError: If ``images`` is neither uint8 nor float32.
    """
    batch_shape = images.shape[:-3]
    # The resize/pad calls below are written for exactly one batch axis, so collapse
    # whatever leading dimensions were given (none, or several) into a single one.
    if images.ndim != 4:
        images = images.reshape((-1, *images.shape[-3:]))

    cur_height, cur_width, channels = images.shape[1:]
    # One shared scale factor keeps the aspect ratio; the larger ratio guarantees the
    # resized image fits in both dimensions. int() truncates, so the result never
    # exceeds the target size.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)
    resized_images = jax.image.resize(
        images, (images.shape[0], resized_height, resized_width, channels), method=method
    )

    if images.dtype == jnp.uint8:
        # Interpolation can produce fractional values; round back into the uint8 range.
        resized_images = jnp.round(resized_images).clip(0, 255).astype(jnp.uint8)
    elif images.dtype == jnp.float32:
        # Some interpolation methods can overshoot; keep the documented [-1, 1] range.
        resized_images = resized_images.clip(-1.0, 1.0)
    else:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Centre the resized image; any odd leftover pixel goes to the bottom/right side.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w
    padded_images = jnp.pad(
        resized_images,
        ((0, 0), (pad_h0, pad_h1), (pad_w0, pad_w1), (0, 0)),
        # "Black" is 0 for uint8 images and -1.0 for images normalised to [-1, 1].
        constant_values=0 if images.dtype == jnp.uint8 else -1.0,
    )

    # Restore the caller's leading dimensions (a single batch axis is already in place).
    if len(batch_shape) != 1:
        padded_images = padded_images.reshape((*batch_shape, height, width, channels))
    return padded_images
|
|
|
|
|
|
|
|
def resize_with_pad_torch(
    images: torch.Tensor,
    height: int,
    width: int,
    mode: str = "bilinear",
) -> torch.Tensor:
    """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion
    by padding with black. If the image is float32, it must be in the range [-1, 1].

    The image is scaled by a single factor so that it fits inside ``height`` x ``width`` while keeping
    its aspect ratio, then centred; the remaining border is filled with 0 (uint8) or -1.0 (float32).
    When the leftover space is odd, the extra row/column goes to the bottom/right.

    The memory layout is guessed from the last dimension: a size of 4 or less is treated as the
    channel axis (channels-last), anything larger as the width of a channels-first tensor. A
    channels-first image whose width is 4 or less is therefore misread as channels-last.

    Args:
        images: Tensor of shape [*b, h, w, c] or [*b, c, h, w], with any number of leading batch
            dimensions (including none). Must be uint8 or float32.
        height: Target height
        width: Target width
        mode: Interpolation mode ('bilinear', 'nearest', etc.)

    Returns:
        Resized and padded tensor with same shape format as input: the layout and the leading
        batch dimensions are preserved exactly, only the two spatial sizes change.

    Raises:
        ValueError: If ``images`` has fewer than 3 dimensions or is neither uint8 nor float32.
    """
    if images.dim() < 3:
        raise ValueError(f"Expected images with at least 3 dimensions, got shape {tuple(images.shape)}")
    is_uint8 = images.dtype == torch.uint8
    if not is_uint8 and images.dtype != torch.float32:
        raise ValueError(f"Unsupported image dtype: {images.dtype}")

    # Remember the caller's leading dimensions so exactly those are restored at the end;
    # a batch axis is only ever removed if this function added it.
    batch_shape = tuple(images.shape[:-3])
    channels_last = images.shape[-1] <= 4

    # F.interpolate expects [b, c, h, w]: collapse/add batch dims, then fix the layout.
    if images.dim() != 4:
        images = images.reshape(-1, *images.shape[-3:])
    if channels_last:
        images = images.permute(0, 3, 1, 2)  # [b, h, w, c] -> [b, c, h, w]

    cur_height, cur_width = images.shape[-2:]

    # One shared scale factor keeps the aspect ratio; the larger ratio guarantees the
    # resized image fits in both dimensions. int() truncates, so the result never
    # exceeds the target size.
    ratio = max(cur_width / width, cur_height / height)
    resized_height = int(cur_height / ratio)
    resized_width = int(cur_width / ratio)

    # Interpolate in float32 even for uint8 input: integer interpolation kernels are not
    # available on every backend, and rounding is only meaningful on a float result.
    source = images.to(torch.float32) if is_uint8 else images
    resized_images = F.interpolate(
        source,
        size=(resized_height, resized_width),
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
    )

    if is_uint8:
        resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8)
    else:
        # Some interpolation modes can overshoot; keep the documented [-1, 1] range.
        resized_images = resized_images.clamp(-1.0, 1.0)

    # Centre the resized image; any odd leftover pixel goes to the bottom/right side.
    pad_h0, remainder_h = divmod(height - resized_height, 2)
    pad_h1 = pad_h0 + remainder_h
    pad_w0, remainder_w = divmod(width - resized_width, 2)
    pad_w1 = pad_w0 + remainder_w

    padded_images = F.pad(
        resized_images,
        (pad_w0, pad_w1, pad_h0, pad_h1),  # left, right, top, bottom
        mode="constant",
        # "Black" is 0 for uint8 images and -1.0 for images normalised to [-1, 1].
        value=0 if is_uint8 else -1.0,
    )

    if channels_last:
        padded_images = padded_images.permute(0, 2, 3, 1)  # [b, c, h, w] -> [b, h, w, c]
    # Restore the caller's leading dimensions (a single batch axis is already in place).
    if len(batch_shape) != 1:
        padded_images = padded_images.reshape(*batch_shape, *padded_images.shape[1:])

    return padded_images
|
|
|