# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, ClassVar, Literal
import albumentations as A
import cv2
import numpy as np
import torch
import torchvision.transforms.v2 as T
from einops import rearrange
from pydantic import Field, PrivateAttr, field_validator
from PIL import Image
from ..schema import DatasetMetadata
from .base import ModalityTransform
class VideoTransform(ModalityTransform):
# Configurable attributes
backend: str = Field(
default="torchvision", description="The backend to use for the transformations"
)
# Model variables
_train_transform: Callable | None = PrivateAttr(default=None)
_eval_transform: Callable | None = PrivateAttr(default=None)
_original_resolutions: dict[str, tuple[int, int]] = PrivateAttr(default_factory=dict)
# Model constants
    _INTERPOLATION_MAP: ClassVar[dict[str, dict[str, Any]]] = {
        "nearest": {
            "albumentations": cv2.INTER_NEAREST,
            "torchvision": T.InterpolationMode.NEAREST,
        },
        "linear": {
            "albumentations": cv2.INTER_LINEAR,
            "torchvision": T.InterpolationMode.BILINEAR,
        },
        "cubic": {
            "albumentations": cv2.INTER_CUBIC,
            "torchvision": T.InterpolationMode.BICUBIC,
        },
        "area": {
            "albumentations": cv2.INTER_AREA,
            "torchvision": None,  # Torchvision does not support this interpolation mode
        },
        "lanczos4": {
            "albumentations": cv2.INTER_LANCZOS4,  # Lanczos with a 4x4 filter
            "torchvision": T.InterpolationMode.LANCZOS,  # Torchvision does not specify the filter size; it may differ from 4x4
        },
        "linear_exact": {
            "albumentations": cv2.INTER_LINEAR_EXACT,
            "torchvision": None,  # Torchvision does not support this interpolation mode
        },
        "nearest_exact": {
            "albumentations": cv2.INTER_NEAREST_EXACT,
            "torchvision": T.InterpolationMode.NEAREST_EXACT,
        },
        "max": {
            "albumentations": cv2.INTER_MAX,
            "torchvision": None,  # Torchvision does not support this interpolation mode
        },
    }
@property
def train_transform(self) -> Callable:
assert (
self._train_transform is not None
), "Transform is not set. Please call set_metadata() before calling apply()."
return self._train_transform
@train_transform.setter
def train_transform(self, value: Callable):
self._train_transform = value
@property
def eval_transform(self) -> Callable | None:
return self._eval_transform
@eval_transform.setter
def eval_transform(self, value: Callable | None):
self._eval_transform = value
@property
def original_resolutions(self) -> dict[str, tuple[int, int]]:
assert (
self._original_resolutions is not None
), "Original resolutions are not set. Please call set_metadata() before calling apply()."
return self._original_resolutions
@original_resolutions.setter
def original_resolutions(self, value: dict[str, tuple[int, int]]):
self._original_resolutions = value
    def check_input(self, data: dict[str, Any]):
        if self.backend == "torchvision":
            for key in self.apply_to:
                assert isinstance(data[key], torch.Tensor), f"Video {key} is not a torch tensor"
                assert data[key].ndim in [
                    4,
                    5,
                ], f"Expected video {key} to have 4 or 5 dimensions (T, C, H, W or B, T, C, H, W), got {data[key].ndim}"
        elif self.backend == "albumentations":
            for key in self.apply_to:
                assert isinstance(data[key], np.ndarray), f"Video {key} is not a numpy array"
                assert data[key].ndim in [
                    4,
                    5,
                ], f"Expected video {key} to have 4 or 5 dimensions (T, H, W, C or B, T, H, W, C), got {data[key].ndim}"
        else:
            raise ValueError(f"Backend {self.backend} not supported")
def set_metadata(self, dataset_metadata: DatasetMetadata):
super().set_metadata(dataset_metadata)
self.original_resolutions = {}
for key in self.apply_to:
split_keys = key.split(".")
assert len(split_keys) == 2, f"Invalid key: {key}. Expected format: modality.key"
sub_key = split_keys[1]
if sub_key in dataset_metadata.modalities.video:
self.original_resolutions[key] = dataset_metadata.modalities.video[
sub_key
].resolution
else:
raise ValueError(
f"Video key {sub_key} not found in dataset metadata. Available keys: {dataset_metadata.modalities.video.keys()}"
)
train_transform = self.get_transform(mode="train")
eval_transform = self.get_transform(mode="eval")
if self.backend == "albumentations":
self.train_transform = A.ReplayCompose(transforms=[train_transform]) # type: ignore
if eval_transform is not None:
self.eval_transform = A.ReplayCompose(transforms=[eval_transform]) # type: ignore
else:
assert train_transform is not None, "Train transform must be set"
self.train_transform = train_transform
self.eval_transform = eval_transform
    def apply(self, data: dict[str, Any]) -> dict[str, Any]:
        if self.training:
            transform = self.train_transform
        else:
            transform = self.eval_transform
        if transform is None:
            # Train-only transforms have no eval counterpart, so eval is a no-op.
            return data
try:
self.check_input(data)
except AssertionError as e:
raise ValueError(
f"Input data does not match the expected format for {self.__class__.__name__}: {e}"
) from e
# Concatenate views
views = [data[key] for key in self.apply_to]
num_views = len(views)
is_batched = views[0].ndim == 5
bs = views[0].shape[0] if is_batched else 1
if isinstance(views[0], torch.Tensor):
views = torch.cat(views, 0)
elif isinstance(views[0], np.ndarray):
views = np.concatenate(views, 0)
else:
raise ValueError(f"Unsupported view type: {type(views[0])}")
if is_batched:
views = rearrange(views, "(v b) t c h w -> (v b t) c h w", v=num_views, b=bs)
# Apply the transform
if self.backend == "torchvision":
views = transform(views)
elif self.backend == "albumentations":
assert isinstance(transform, A.ReplayCompose), "Transform must be a ReplayCompose"
first_frame = views[0]
transformed = transform(image=first_frame)
replay_data = transformed["replay"]
transformed_first_frame = transformed["image"]
if len(views) > 1:
# Apply the same transformations to the rest of the frames
transformed_frames = [
transform.replay(replay_data, image=frame)["image"] for frame in views[1:]
]
# Add the first frame back
transformed_frames = [transformed_first_frame] + transformed_frames
else:
# If there is only one frame, just make a list with one frame
transformed_frames = [transformed_first_frame]
# Delete the replay data to save memory
del replay_data
views = np.stack(transformed_frames, 0)
else:
raise ValueError(f"Backend {self.backend} not supported")
# Split views
if is_batched:
views = rearrange(views, "(v b t) c h w -> v b t c h w", v=num_views, b=bs)
else:
views = rearrange(views, "(v t) c h w -> v t c h w", v=num_views)
for key, view in zip(self.apply_to, views):
data[key] = view
return data
@classmethod
def _validate_interpolation(cls, interpolation: str):
if interpolation not in cls._INTERPOLATION_MAP:
raise ValueError(f"Interpolation mode {interpolation} not supported")
def _get_interpolation(self, interpolation: str, backend: str = "torchvision"):
"""
Get the interpolation mode for the given backend.
Args:
interpolation (str): The interpolation mode.
backend (str): The backend to use.
Returns:
Any: The interpolation mode for the given backend.
"""
return self._INTERPOLATION_MAP[interpolation][backend]
    def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
        raise NotImplementedError(
            "get_transform is not implemented for VideoTransform. Subclasses must implement it to provide their transforms."
        )
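# Usage sketch (illustrative, not part of the module API): a concrete subclass
# such as VideoResize (defined below) is configured, given metadata, and then
# applied. The key name "video.ego_view" and the `dataset_metadata` object are
# assumptions for the example; the `training` flag consulted by apply() is
# assumed to be managed by the ModalityTransform base class.
#
#   resize = VideoResize(apply_to=["video.ego_view"], height=224, width=224)
#   resize.set_metadata(dataset_metadata)  # builds the train/eval transforms
#   out = resize.apply({"video.ego_view": frames})  # frames: (T, C, H, W) tensor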
class VideoCrop(VideoTransform):
height: int | None = Field(default=None, description="The height of the input image")
width: int | None = Field(default=None, description="The width of the input image")
scale: float = Field(
...,
description="The scale of the crop. The crop size is (width * scale, height * scale)",
)
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
"""Get the transform for the given mode.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable: If mode is "train", return a random crop transform. If mode is "eval", return a center crop transform.
"""
# 1. Check the input resolution
assert (
len(set(self.original_resolutions.values())) == 1
), f"All video keys must have the same resolution, got: {self.original_resolutions}"
if self.height is None:
assert self.width is None, "Height and width must be either both provided or both None"
self.width, self.height = self.original_resolutions[self.apply_to[0]]
else:
assert (
self.width is not None
), "Height and width must be either both provided or both None"
# 2. Create the transform
size = (int(self.height * self.scale), int(self.width * self.scale))
if self.backend == "torchvision":
if mode == "train":
return T.RandomCrop(size)
elif mode == "eval":
return T.CenterCrop(size)
else:
raise ValueError(f"Crop mode {mode} not supported")
elif self.backend == "albumentations":
if mode == "train":
return A.RandomCrop(height=size[0], width=size[1], p=1)
elif mode == "eval":
return A.CenterCrop(height=size[0], width=size[1], p=1)
else:
raise ValueError(f"Crop mode {mode} not supported")
else:
raise ValueError(f"Backend {self.backend} not supported")
def check_input(self, data: dict[str, Any]):
super().check_input(data)
# Check the input resolution
for key in self.apply_to:
if self.backend == "torchvision":
height, width = data[key].shape[-2:]
elif self.backend == "albumentations":
height, width = data[key].shape[-3:-1]
else:
raise ValueError(f"Backend {self.backend} not supported")
assert (
height == self.height and width == self.width
), f"Video {key} has invalid shape {height, width}, expected {self.height, self.width}"
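# Example (hedged, with made-up numbers): if the metadata records a source
# resolution of (width=640, height=480) and scale=0.9, the crop size becomes
# (height, width) = (432, 576); train mode samples the crop location randomly,
# eval mode crops at the center.
#
#   crop = VideoCrop(apply_to=["video.ego_view"], scale=0.9)
#   crop.set_metadata(dataset_metadata)  # fills height/width from the metadata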
class VideoResize(VideoTransform):
height: int = Field(..., description="The height of the resize")
width: int = Field(..., description="The width of the resize")
interpolation: str = Field(default="linear", description="The interpolation mode")
antialias: bool = Field(default=True, description="Whether to apply antialiasing")
@field_validator("interpolation")
def validate_interpolation(cls, v):
cls._validate_interpolation(v)
return v
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
"""Get the resize transform. Same transform for both train and eval.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable: The resize transform.
"""
        interpolation = self._get_interpolation(self.interpolation, self.backend)
        if interpolation is None:
            raise ValueError(
                f"Interpolation mode {self.interpolation} not supported for backend {self.backend}"
            )
if self.backend == "torchvision":
size = (self.height, self.width)
return T.Resize(size, interpolation=interpolation, antialias=self.antialias)
elif self.backend == "albumentations":
return A.Resize(
height=self.height,
width=self.width,
interpolation=interpolation,
p=1,
)
else:
raise ValueError(f"Backend {self.backend} not supported")
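# Example (hedged): the albumentations backend expects (T, H, W, C) uint8 numpy
# frames; set_metadata() wraps the transform in A.ReplayCompose so every frame
# of a clip is resized with identical parameters. "area" maps to cv2.INTER_AREA,
# which has no torchvision equivalent in the interpolation table above.
#
#   resize = VideoResize(
#       apply_to=["video.ego_view"],
#       height=224,
#       width=224,
#       interpolation="area",
#       backend="albumentations",
#   )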
class VideoRandomRotation(VideoTransform):
degrees: float | tuple[float, float] = Field(
..., description="The degrees of the random rotation"
)
interpolation: str = Field("linear", description="The interpolation mode")
@field_validator("interpolation")
def validate_interpolation(cls, v):
cls._validate_interpolation(v)
return v
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
"""Get the random rotation transform, only used in train mode.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable | None: The random rotation transform. None for eval mode.
"""
if mode == "eval":
return None
        interpolation = self._get_interpolation(self.interpolation, self.backend)
        if interpolation is None:
            raise ValueError(
                f"Interpolation mode {self.interpolation} not supported for backend {self.backend}"
            )
if self.backend == "torchvision":
return T.RandomRotation(self.degrees, interpolation=interpolation) # type: ignore
elif self.backend == "albumentations":
return A.Rotate(limit=self.degrees, interpolation=interpolation, p=1)
else:
raise ValueError(f"Backend {self.backend} not supported")
class VideoHorizontalFlip(VideoTransform):
p: float = Field(..., description="The probability of the horizontal flip")
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
"""Get the horizontal flip transform, only used in train mode.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable | None: If mode is "train", return a horizontal flip transform. If mode is "eval", return None.
"""
if mode == "eval":
return None
if self.backend == "torchvision":
return T.RandomHorizontalFlip(self.p)
elif self.backend == "albumentations":
return A.HorizontalFlip(p=self.p)
else:
raise ValueError(f"Backend {self.backend} not supported")
class VideoGrayscale(VideoTransform):
p: float = Field(..., description="The probability of the grayscale transformation")
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
"""Get the grayscale transform, only used in train mode.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable | None: If mode is "train", return a grayscale transform. If mode is "eval", return None.
"""
if mode == "eval":
return None
if self.backend == "torchvision":
return T.RandomGrayscale(self.p)
elif self.backend == "albumentations":
return A.ToGray(p=self.p)
else:
raise ValueError(f"Backend {self.backend} not supported")
class VideoColorJitter(VideoTransform):
brightness: float | tuple[float, float] = Field(
..., description="The brightness of the color jitter"
)
contrast: float | tuple[float, float] = Field(
..., description="The contrast of the color jitter"
)
saturation: float | tuple[float, float] = Field(
..., description="The saturation of the color jitter"
)
hue: float | tuple[float, float] = Field(..., description="The hue of the color jitter")
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
"""Get the color jitter transform, only used in train mode.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable | None: If mode is "train", return a color jitter transform. If mode is "eval", return None.
"""
if mode == "eval":
return None
if self.backend == "torchvision":
return T.ColorJitter(
brightness=self.brightness,
contrast=self.contrast,
saturation=self.saturation,
hue=self.hue,
)
elif self.backend == "albumentations":
return A.ColorJitter(
brightness=self.brightness,
contrast=self.contrast,
saturation=self.saturation,
hue=self.hue,
p=1,
)
else:
raise ValueError(f"Backend {self.backend} not supported")
class VideoRandomGrayscale(VideoTransform):
p: float = Field(..., description="The probability of the grayscale transformation")
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
"""Get the grayscale transform, only used in train mode.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable | None: If mode is "train", return a grayscale transform. If mode is "eval", return None.
"""
if mode == "eval":
return None
if self.backend == "torchvision":
return T.RandomGrayscale(self.p)
elif self.backend == "albumentations":
return A.ToGray(p=self.p)
else:
raise ValueError(f"Backend {self.backend} not supported")
class VideoRandomPosterize(VideoTransform):
bits: int = Field(..., description="The number of bits to posterize the image")
p: float = Field(..., description="The probability of the posterize transformation")
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable | None:
"""Get the posterize transform, only used in train mode.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable | None: If mode is "train", return a posterize transform. If mode is "eval", return None.
"""
if mode == "eval":
return None
if self.backend == "torchvision":
return T.RandomPosterize(bits=self.bits, p=self.p)
elif self.backend == "albumentations":
return A.Posterize(num_bits=self.bits, p=self.p)
else:
raise ValueError(f"Backend {self.backend} not supported")
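# Note on the train-only augmentations above (rotation, flip, grayscale, color
# jitter, posterize): their get_transform() returns None in eval mode, so
# apply() passes data through unchanged outside training. A hedged composition
# sketch, assuming an outer pipeline applies the transforms in sequence and
# `keys` / `dataset_metadata` are defined by the caller:
#
#   augmentations = [
#       VideoHorizontalFlip(apply_to=keys, p=0.5),
#       VideoColorJitter(apply_to=keys, brightness=0.3, contrast=0.4, saturation=0.5, hue=0.08),
#   ]
#   for aug in augmentations:
#       aug.set_metadata(dataset_metadata)
#       data = aug.apply(data)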
class VideoToTensor(VideoTransform):
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
"""Get the to tensor transform. Same transform for both train and eval.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable: The to tensor transform.
"""
if self.backend == "torchvision":
return self.__class__.to_tensor
else:
raise ValueError(f"Backend {self.backend} not supported")
def check_input(self, data: dict):
"""Check if the input data has the correct shape.
Expected video shape: [T, H, W, C], dtype np.uint8
"""
for key in self.apply_to:
assert key in data, f"Key {key} not found in data. Available keys: {data.keys()}"
assert data[key].ndim in [
4,
5,
], f"Video {key} must have 4 or 5 dimensions, got {data[key].ndim}"
assert (
data[key].dtype == np.uint8
), f"Video {key} must have dtype uint8, got {data[key].dtype}"
input_resolution = data[key].shape[-3:-1][::-1]
if key in self.original_resolutions:
expected_resolution = self.original_resolutions[key]
else:
expected_resolution = input_resolution
assert (
input_resolution == expected_resolution
), f"Video {key} has invalid resolution {input_resolution}, expected {expected_resolution}. Full shape: {data[key].shape}"
@staticmethod
def to_tensor(frames: np.ndarray) -> torch.Tensor:
"""Convert numpy array to tensor efficiently.
Args:
frames: numpy array of shape [T, H, W, C] in uint8 format
Returns:
tensor of shape [T, C, H, W] in range [0, 1]
"""
frames_tensor = torch.from_numpy(frames).to(torch.float32) / 255.0
return frames_tensor.permute(0, 3, 1, 2) # [T, C, H, W]
class VideoToNumpy(VideoTransform):
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
"""Get the to numpy transform. Same transform for both train and eval.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable: The to numpy transform.
"""
if self.backend == "torchvision":
return self.__class__.to_numpy
else:
raise ValueError(f"Backend {self.backend} not supported")
@staticmethod
def to_numpy(frames: torch.Tensor) -> np.ndarray:
"""Convert tensor back to numpy array efficiently.
Args:
frames: tensor of shape [T, C, H, W] in range [0, 1]
Returns:
numpy array of shape [T, H, W, C] in uint8 format
"""
return (frames.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
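# Round-trip sketch (hedged): to_tensor and to_numpy bridge the albumentations
# (numpy, T H W C) and torchvision (tensor, T C H W) stages of a pipeline. The
# round trip is lossy by at most one uint8 count due to float truncation:
#
#   frames = np.random.randint(0, 256, (8, 224, 224, 3), dtype=np.uint8)
#   tensor = VideoToTensor.to_tensor(frames)  # (8, 3, 224, 224), float32 in [0, 1]
#   back = VideoToNumpy.to_numpy(tensor)      # (8, 224, 224, 3), uint8
#   assert np.abs(frames.astype(int) - back.astype(int)).max() <= 1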
class VideoToPIL(VideoTransform):
def get_transform(self, mode: Literal["train", "eval"] = "train") -> Callable:
"""Get the to PIL transform. Same transform for both train and eval.
Args:
mode (Literal["train", "eval"]): The mode to get the transform for.
Returns:
Callable: The to PIL transform.
"""
if self.backend == "torchvision":
return self.__class__.to_pil
else:
raise ValueError(f"Backend {self.backend} not supported")
    @staticmethod
    def to_pil(frames: torch.Tensor) -> list[Image.Image]:
        """Convert a video tensor to a list of PIL Images, one per frame.
        Args:
            frames: tensor of shape [T, C, H, W] in range [0, 1]
        Returns:
            List of T PIL Images, each [H, W, C] in uint8 format
        """
        # PIL has no multi-frame in-memory video type, so convert frame by frame.
        arrays = (frames.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return [Image.fromarray(frame) for frame in arrays]
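# Example (hedged): the per-frame conversion yields a list of PIL Images, which
# can be saved as an animated GIF via Pillow's multi-frame save:
#
#   pil_frames = VideoToPIL.to_pil(video_tensor)  # list of T PIL Images
#   pil_frames[0].save("clip.gif", save_all=True, append_images=pil_frames[1:])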