"""
Scalable, Detailed and Mask-free Universal Photometric Stereo Network (CVPR2023)
# Copyright (c) 2023 Satoshi Ikehata
# All rights reserved.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def divide_tensor_spatial(x, block_size=256, method='tile_stride'):
    assert x.dim() == 4, "Input tensor must have 4 dimensions [B, C, H, W]"
    B, C, H, W = x.shape
    assert H == W, "Height and width must be equal"
    assert H % block_size == 0 and W % block_size == 0, "H and W must be divisible by block_size"
    mosaic_scale = H // block_size
    if method == 'tile_stride':
        # Decompose x into K = mosaic_scale**2 subsampled grids of size
        # (block_size, block_size); grid k holds every mosaic_scale-th pixel
        # of x at subsampling phase k.
        K = mosaic_scale * mosaic_scale
        fold_params_grid = dict(kernel_size=(mosaic_scale, mosaic_scale), stride=(mosaic_scale, mosaic_scale), padding=(0, 0), dilation=(1, 1))
        unfold_grid = nn.Unfold(**fold_params_grid)
        tensor_grids = unfold_grid(x)  # (B, C * K, block_size * block_size)
        tensor_grids = tensor_grids.reshape(B, C, K, block_size, block_size).permute(0, 2, 1, 3, 4)  # (B, K, C, block_size, block_size)
        return tensor_grids
    if method == 'tile_block':
        # Decompose x into K = mosaic_scale**2 contiguous non-overlapping tiles.
        tensor_blocks = x.view(B, C, mosaic_scale, block_size, mosaic_scale, block_size)
        tensor_blocks = tensor_blocks.permute(0, 2, 4, 1, 3, 5)  # (B, mosaic_scale, mosaic_scale, C, block_size, block_size)
        tensor_blocks = tensor_blocks.contiguous().view(B, mosaic_scale ** 2, C, block_size, block_size)  # (B, K, C, block_size, block_size)
        return tensor_blocks
    raise ValueError(f"unknown method: {method}")
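
# Illustrative shapes (example values, not from the original code): for x of
# shape (2, 3, 1024, 1024) and block_size=256, mosaic_scale = 4 and K = 16, so
# both methods return a (2, 16, 3, 256, 256) tensor; 'tile_stride' fills each
# block with one subsampling phase of x, 'tile_block' with one contiguous tile.
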
def merge_tensor_spatial(x, method='tile_stride'):
    K, N, feat_dim, Hm, Wm = x.shape
    mosaic_scale = int(math.sqrt(K))
    if method == 'tile_stride':
        # Inverse of the strided decomposition: fold the K subsampled grids
        # back into an (N, feat_dim, Hm * mosaic_scale, Wm * mosaic_scale) map.
        x = x.reshape(K, N, feat_dim, -1)
        fold_params_grid = dict(kernel_size=(mosaic_scale, mosaic_scale), stride=(mosaic_scale, mosaic_scale), padding=(0, 0), dilation=(1, 1))
        fold_grid = nn.Fold(output_size=(Hm * mosaic_scale, Wm * mosaic_scale), **fold_params_grid)
        x = x.permute(1, 2, 0, 3).reshape(N, feat_dim * K, -1)  # (N, feat_dim * K, Hm * Wm)
        x = fold_grid(x)
        return x
    if method == 'tile_block':
        # Inverse of the tiling decomposition: place the K tiles back on the
        # mosaic grid.
        x = x.permute(1, 0, 2, 3, 4).reshape(N, mosaic_scale, mosaic_scale, feat_dim, Hm, Wm)
        x = x.permute(0, 3, 1, 4, 2, 5)  # (N, feat_dim, mosaic_scale, Hm, mosaic_scale, Wm)
        x = x.reshape(N, feat_dim, mosaic_scale * Hm, mosaic_scale * Wm)
        return x
    raise ValueError(f"unknown method: {method}")
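
# Minimal round-trip sketch (illustrative; `_check_spatial_round_trip` is a
# hypothetical helper, not part of the original module): with a matching
# `method`, merge_tensor_spatial inverts divide_tensor_spatial up to the
# (B, K, ...) -> (K, B, ...) permutation of the block axis.
def _check_spatial_round_trip():
    x = torch.rand(2, 3, 512, 512)
    for method in ('tile_stride', 'tile_block'):
        blocks = divide_tensor_spatial(x, block_size=256, method=method)  # (2, 4, 3, 256, 256)
        merged = merge_tensor_spatial(blocks.permute(1, 0, 2, 3, 4), method=method)
        assert torch.allclose(x, merged), method
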
def divide_overlapping_patches(input_tensor, patch_size, margin):
    # Assumes a square input (B, C, W, W); adjacent patches overlap by
    # `margin` pixels.
    B, C, W, _ = input_tensor.shape
    stride = patch_size - margin
    # Smallest padded size such that patches of `patch_size` taken with this
    # stride cover the whole image.
    padded_W = ((W - patch_size + stride - 1) // stride) * stride + patch_size
    pad = padded_W - W
    padded_tensor = F.pad(input_tensor, (0, pad, 0, pad), mode='constant', value=0)
    patches = F.unfold(padded_tensor, kernel_size=patch_size, stride=stride)  # (B, C * patch_size**2, L)
    patches = patches.view(B, C, patch_size, patch_size, -1).permute(0, 4, 1, 2, 3)  # (B, L, C, patch_size, patch_size)
    return patches
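
# Illustrative shapes (example values): for input (1, 3, 500, 500) with
# patch_size=256 and margin=32, stride = 224 and the image is zero-padded to
# 704x704, giving a 3x3 grid of patches, i.e. an output of (1, 9, 3, 256, 256).
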
def merge_overlapping_patches(patches, patch_size, margin, original_size):
    B, _, C, _, _ = patches.shape
    stride = patch_size - margin
    W = original_size[2]
    # Fold back onto the padded canvas the patches were cut from, then crop;
    # folding directly to (W, W) fails whenever divide_overlapping_patches
    # added padding, because the number of blocks no longer matches.
    padded_W = ((W - patch_size + stride - 1) // stride) * stride + patch_size
    patches = patches.permute(0, 2, 3, 4, 1).contiguous().view(B, C * patch_size * patch_size, -1)
    output = F.fold(patches, (padded_W, padded_W), kernel_size=patch_size, stride=stride)
    # fold sums overlapping regions; divide by the per-pixel patch count to
    # average them instead.
    weight = torch.ones_like(patches)
    weight = F.fold(weight, (padded_W, padded_W), kernel_size=patch_size, stride=stride)
    output = output / weight
    return output[:, :, :W, :W]
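
# Hypothetical self-test (not in the original file): runs both round trips on
# random data; the reconstructions are exact up to floating-point error from
# the overlap averaging.
if __name__ == "__main__":
    _check_spatial_round_trip()
    x = torch.rand(1, 3, 500, 500)
    patches = divide_overlapping_patches(x, patch_size=256, margin=32)
    merged = merge_overlapping_patches(patches, patch_size=256, margin=32, original_size=x.shape)
    assert torch.allclose(x, merged, atol=1e-6)
    print("round-trip checks passed")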