import math
import random
from typing import Tuple, Optional
import kornia
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import torchvision.transforms as tf
from scenedino.models.backbones.dino.decoder import NoDecoder
import logging
logger = logging.getLogger("training")
class MultiScaleCropGT_kornia(nn.Module):
"""This class implements multi-scale-crop augmentation for DINO features."""
def __init__(
self,
gt_encoder: nn.Module,
num_views: int = 8,
image_size: Tuple[int, int] = (192, 640),
feature_stride: int = 16,
) -> None:
"""Constructor method.
Args:
num_views (int): Number of view per image. Default 8.
augmentations (Tuple[AugmentationBase2D, ...]): Geometric augmentations to be applied.
feature_stride (int): Stride of the features. Default 16.
"""
# Call super constructor
super(MultiScaleCropGT_kornia, self).__init__()
# GT encoder
self.gt_encoder = gt_encoder
# Save parameters
self.augmentations_per_sample: int = num_views
self.feature_stride: int = feature_stride
# Init augmentations
image_ratio = image_size[0] / image_size[1]
augmentations = (
kornia.augmentation.RandomHorizontalFlip(p=0.5),
#kornia.augmentation.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
kornia.augmentation.RandomResizedCrop(
scale=(0.5, 1.0), size=tuple(image_size), ratio=(image_ratio/1.2, image_ratio*1.2), p=1.0
                # size must match the input image resolution
),
)
self.augmentations: nn.Module = kornia.augmentation.VideoSequential(*augmentations, same_on_frame=True)
@staticmethod
def _affine_transform_valid_pixels(transform: Tensor, mask: Tensor) -> Tensor:
"""Applies affine transform to a mask of ones to estimate valid pixels.
Args:
transform (Tensor): Affine transform of the shape [B, 3, 3]
mask (Tensor): Mask of the shape [B, 1, H, W].
Returns:
valid_pixels (Tensor): Mask of valid pixels of the shape [B, 1, H, W].
"""
# Get shape
H, W = mask.shape[2:] # type: int, int
# Resample mask map
valid_pixels: Tensor = kornia.geometry.warp_perspective(
mask,
transform,
(H, W),
mode="nearest",
)
# Threshold mask
valid_pixels = torch.where( # type: ignore
valid_pixels > 0.999, torch.ones_like(valid_pixels), torch.zeros_like(valid_pixels)
)
return valid_pixels
def _accumulate_predictions(self, features: Tensor, transforms: Tensor) -> Tensor:
"""Accumulates features over multiple predictions.
Args:
features (Tensor): Feature predictions of the shape [B, num_views, H, W].
transforms (Tensor): Affine transformations of the shape [B, num_views, 3, 3].
Returns:
optical_flow_predictions_accumulated (Tensor): Accumulated optical flow of the shape [B, 2, H, W].
"""
# Get shape
B, N, C, H, W = features.shape # type: int, int, int, int, int
# Get base and augmented views
features_base = features[:, -2:]
features_augmented = features[:, :-2]
# Combine batch dimension and view dimension
features_augmented = features_augmented.flatten(0, 1)
transforms = transforms.flatten(0, 1)
        # Rescale translation to feature resolution (disabled here, since the features
        # are interpolated back to full image resolution before warping)
        transforms[:, 0, -1] = transforms[:, 0, -1]  # / float(self.feature_stride)
        transforms[:, 1, -1] = transforms[:, 1, -1]  # / float(self.feature_stride)
# Invert transformations
transforms_inv: Tensor = torch.inverse(transforms)
        # Warp the augmented feature maps back into the base view
features_resampled: Tensor = kornia.geometry.warp_perspective(
features_augmented,
transforms_inv,
(H, W),
mode="bilinear",
)
# Separate batch and view dimension again
features_resampled = features_resampled.reshape(B, -1, C, H, W)
# Add base views
features_resampled = torch.cat((features_resampled, features_base), dim=1)
# Reverse flip
features_resampled[:, -2] = features_resampled[:, -2].flip(dims=(-1,))
# Compute valid pixels
mask: Tensor = torch.ones(
B, N - 2, 1, H, W, dtype=features_resampled.dtype, device=features_resampled.device
)
mask = mask.flatten(0, 1)
valid_pixels: Tensor = self._affine_transform_valid_pixels(transforms_inv, mask)
valid_pixels = valid_pixels.reshape(B, N - 2, 1, H, W)
valid_pixels = F.pad(valid_pixels, (0, 0, 0, 0, 0, 0, 0, 2), value=1)
        # Mark invalid (out-of-view) feature values as NaN so they are ignored in the average
features_resampled[valid_pixels.repeat(1, 1, C, 1, 1) == 0.0] = torch.nan
        # Average the features over the different views, ignoring NaN (invalid) pixels
# logger.info(features_resampled.shape)
return features_resampled.nanmean(dim=1)
def _get_augmentations(self, images: Tensor) -> Tuple[Tensor, Tensor]:
"""Forward pass generates different augmentations of the input images.
Args:
images (Tensor): Images of the shape [B, 3, H, W]
Returns:
images_augmented (Tensor): Augmented images of the shape [B, N, H, W].
transforms (Tensor): Transformations of the shape [B, N, 3, 3].
"""
        # Add a frame dimension: shape becomes [B, 1, 3, H, W]
images = images[:, None]
# Init tensor to store transformations
transformations: Tensor = torch.empty(
images.shape[0], self.augmentations_per_sample - 2, 3, 3, dtype=torch.float32, device=images.device
)
# Init tensor to store augmented images
images_augmented: Tensor = torch.empty_like(images)
images_augmented = images_augmented[:, None].repeat_interleave(self.augmentations_per_sample, dim=1)
# Save original and flipped images
images_augmented[:, -1] = images.clone()
images_augmented[:, -2] = images.clone().flip(dims=(-1,))
# Apply geometric augmentations
for index in range(images.shape[0]):
images_repeated: Tensor = images[index][None].repeat_interleave(self.augmentations_per_sample - 2, dim=0)
images_augmented[index, :-2] = self.augmentations(images_repeated)
transformations[index] = self.augmentations.get_transformation_matrix(
images_repeated, self.augmentations._params
)
return images_augmented[:, :, 0], transformations
def forward_chunk(self, images):
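        """Computes accumulated GT features for one chunk of images.
        Args:
            images (Tensor): Images of the shape [B, 3, H, W].
        Returns:
            features (Tensor): L2-normalized accumulated features of the shape [B, C, H, W].
        """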
batch_size, _, h, w = images.shape
# Perform augmentation
images_aug, transformations = self._get_augmentations(images)
# Get representations
features = self.gt_encoder(images_aug.flatten(0, 1))[-1]
features = F.interpolate(features, size=(h, w), mode="bilinear")
# features = features.repeat_interleave(self.feature_stride, -1).repeat_interleave(self.feature_stride, -2)
_, dino_dim, _, _ = features.shape
features = features.view(batch_size, -1, dino_dim, h, w)
        chunks = torch.chunk(features, chunks=4, dim=2)  # Split into 4 parts along the feature dimension (dim=2) to limit memory
chunks = [self._accumulate_predictions(chunk, transformations) for chunk in chunks]
features_accumulated = torch.cat(chunks, dim=1)
# features_accumulated = self._accumulate_predictions(features, transformations)
return features_accumulated / torch.linalg.norm(features_accumulated, dim=1, keepdim=True)
def forward(self, images):
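        """Splits the batch into chunks when the total number of augmented views exceeds max_chunk.
        Args:
            images (Tensor): Images of the shape [B, 3, H, W].
        Returns:
            features (list): Single-element list with accumulated features of the shape [B, C, H, W].
        """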
max_chunk = 16
aug_no_images = images.shape[0] * self.augmentations_per_sample
if aug_no_images > max_chunk:
            # Ceil division so that no chunk produces more than max_chunk augmented views
            no_chunks = math.ceil(aug_no_images / max_chunk)
images = torch.chunk(images, no_chunks)
features = [self.forward_chunk(image) for image in images]
features = torch.cat(features, dim=0)
return [features]
else:
return [self.forward_chunk(images)]
class InterpolatedGT(nn.Module):
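    """Wraps a GT encoder and upsamples its patch features to image resolution."""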
def __init__(self, arch: str, gt_encoder: nn.Module, image_size: Tuple[int, int]):
super().__init__()
self.upsampler = NoDecoder(image_size, arch, normalize_features=False)
self.gt_encoder = gt_encoder
def forward(self, x):
gt_patches = self.gt_encoder(x)
return self.upsampler(gt_patches)
def _get_affine(params, crop_size, batch_size):
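    """Builds the inverse affine matrices (for F.affine_grid) from crop parameters.
    Args:
        params: Iterable of per-sample parameters (dy, dx, alpha, scale, flip), with alpha in degrees.
        crop_size: Output size as (H, W).
        batch_size (int): Number of samples.
    Returns:
        affine (Tensor): Inverse affine matrices of the shape [B, 2, 3]. With the identity
        parameters (0, 0, 0, 1, 1) the result reduces to [[1, 0, 0], [0, 1, 0]].
    """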
# construct affine operator
affine = torch.zeros(batch_size, 2, 3)
aspect_ratio = float(crop_size[0]) / float(crop_size[1])
for i, (dy, dx, alpha, scale, flip) in enumerate(params):
# R inverse
sin = math.sin(alpha * math.pi / 180.)
cos = math.cos(alpha * math.pi / 180.)
# inverse, note how flipping is incorporated
affine[i, 0, 0], affine[i, 0, 1] = flip * cos, sin * aspect_ratio
affine[i, 1, 0], affine[i, 1, 1] = -sin / aspect_ratio, cos
# T inverse Rinv * t == R^T * t
affine[i, 0, 2] = -1. * (cos * dx + sin * dy)
affine[i, 1, 2] = -1. * (-sin * dx + cos * dy)
# T
affine[i, 0, 2] /= float(crop_size[1] // 2)
affine[i, 1, 2] /= float(crop_size[0] // 2)
# scaling
affine[i] *= scale
return affine
class MultiScaleCropGT(nn.Module):
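    """Multi-scale-crop GT feature aggregation based on MaskRandScaleCrop and grid_sample warping."""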
def __init__(self,
gt_encoder: nn.Module,
num_views: int,
scale_from: float = 0.4,
grid_sample_batch: Optional[int] = 96):
super().__init__()
self.gt_encoder = gt_encoder
self.num_views = num_views
self.augmentation = MaskRandScaleCrop(scale_from)
self.grid_sample_batch = grid_sample_batch
def forward(self, x):
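        """Encodes multiple randomly cropped views and averages the features after warping them back.
        Args:
            x (Tensor): Images of the shape [B, 3, H, W].
        Returns:
            result (list): Single-element list with averaged features of the shape [B, C, H, W].
        """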
result = None
count = 0
batch_size, _, h, w = x.shape
for i in range(self.num_views):
if i > 0:
x, params = self.augmentation(x)
else:
params = [[0., 0., 0., 1., 1.] for _ in range(x.shape[0])]
gt_patches = self.gt_encoder(x)[-1]
            affine = _get_affine(params, (h, w), batch_size).to(x.device)
affine_grid_gt = F.affine_grid(affine, x.size(), align_corners=False)
if self.grid_sample_batch:
d = gt_patches.shape[1]
assert d % self.grid_sample_batch == 0
for idx in range(0, d, self.grid_sample_batch):
gt_aligned_batch = F.grid_sample(gt_patches[:, idx:idx+self.grid_sample_batch], affine_grid_gt,
mode="bilinear", align_corners=False)
if result is None:
                    result = torch.zeros(batch_size, d, h, w, device=x.device)
result[:, idx:idx+self.grid_sample_batch] += gt_aligned_batch
else:
gt_aligned = F.grid_sample(gt_patches, affine_grid_gt, mode="bilinear", align_corners=False)
if result is None:
result = 0
result += gt_aligned
within_bounds_x = (affine_grid_gt[..., 0] >= -1) & (affine_grid_gt[..., 0] <= 1)
within_bounds_y = (affine_grid_gt[..., 1] >= -1) & (affine_grid_gt[..., 1] <= 1)
not_padded_mask = within_bounds_x & within_bounds_y
count += not_padded_mask.unsqueeze(1)
count[count == 0] = 1
return [result.div_(count)]
class MaskRandScaleCrop(object):
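    """Random scale-and-crop augmentation that records the parameters needed to invert the crop."""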
def __init__(self, scale_from):
self.scale_from = scale_from
def get_params(self, h, w):
new_scale = random.uniform(self.scale_from, 1)
new_h = int(new_scale * h)
new_w = int(new_scale * w)
i = random.randint(0, h - new_h)
j = random.randint(0, w - new_w)
flip = 1 if random.random() > 0.5 else -1
return i, j, new_h, new_w, new_scale, flip
def __call__(self, images, affine=None):
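        """Crops and resizes each sample in place and records the affine parameters.
        Args:
            images (Tensor): Images of the shape [B, 3, H, W]; each sample is cropped and resized in place.
            affine (list, optional): Per-sample parameter lists to update; defaults to identity parameters.
        Returns:
            images (Tensor): Cropped-and-resized images.
            affine (list): Per-sample parameters [dy, dx, alpha, scale, flip] consumed by _get_affine
            (the stored scale is the inverse of the sampled crop scale).
        """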
if affine is None:
affine = [[0., 0., 0., 1., 1.] for _ in range(len(images))]
_, H, W = images[0].shape
i2 = H / 2
j2 = W / 2
for k, image in enumerate(images):
ii, jj, h, w, s, flip = self.get_params(H, W)
if s == 1.:
continue # no change in scale
# displacement of the centre
dy = ii + h / 2 - i2
dx = jj + w / 2 - j2
affine[k][0] = dy
affine[k][1] = dx
affine[k][3] = 1 / s
# affine[k][4] = flip
assert ii >= 0 and jj >= 0
image_crop = tf.functional.crop(image, ii, jj, h, w)
images[k] = tf.functional.resize(image_crop, (H, W), tf.InterpolationMode.BILINEAR)
return images, affine
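

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (an assumption, not part of the original training
# pipeline). _DummyEncoder, the resolution and the feature dimension are
# illustrative stand-ins; the only requirement taken from the code above is
# that gt_encoder(x) returns an indexable sequence whose last entry is a
# feature map of shape [B, C, H / feature_stride, W / feature_stride].
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _DummyEncoder(nn.Module):
        """Stand-in for the frozen DINO encoder used as gt_encoder."""

        def __init__(self, dim: int = 64, stride: int = 16) -> None:
            super().__init__()
            self.proj = nn.Conv2d(3, dim, kernel_size=stride, stride=stride)

        def forward(self, x: Tensor):
            # Return a list so that gt_encoder(x)[-1] yields the feature map
            return [self.proj(x)]

    images = torch.rand(2, 3, 96, 320)
    model = MultiScaleCropGT_kornia(_DummyEncoder(), num_views=4, image_size=(96, 320))
    with torch.no_grad():
        features = model(images)[0]
    # The accumulated features are L2-normalized over the channel dimension
    print(features.shape)  # expected: torch.Size([2, 64, 96, 320])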