""" | |
Scalable, Detailed and Mask-free Universal Photometric Stereo Network (CVPR2023) | |
# Copyright (c) 2023 Satoshi Ikehata | |
# All rights reserved. | |
""" | |
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def divide_tensor_spatial(x, block_size=256, method='tile_stride'):
    assert x.dim() == 4, "Input tensor must have 4 dimensions [B, C, H, W]"
    B, C, H, W = x.shape
    assert H == W, "Height and Width must be equal"
    assert H % block_size == 0 and W % block_size == 0, "The tensor size must be divisible by the block size"
    mosaic_scale = H // block_size
    K = mosaic_scale * mosaic_scale
    if method == 'tile_stride':
        # Decompose x into K stride-sampled grids: grid k = i * mosaic_scale + j
        # holds x[:, :, i::mosaic_scale, j::mosaic_scale], i.e. each grid is a
        # subsampled copy of the whole image, not a contiguous block.
        fold_params_grid = dict(kernel_size=(mosaic_scale, mosaic_scale), stride=(mosaic_scale, mosaic_scale), padding=(0, 0), dilation=(1, 1))
        unfold_grid = nn.Unfold(**fold_params_grid)
        tensor_grids = unfold_grid(x)  # (B, C * K, block_size * block_size)
        tensor_grids = tensor_grids.reshape(B, C, K, block_size, block_size).permute(0, 2, 1, 3, 4)  # (B, K, C, block_size, block_size)
        return tensor_grids
    if method == 'tile_block':
        # Decompose x into K contiguous, non-overlapping block_size x block_size tiles.
        tensor_blocks = x.view(B, C, mosaic_scale, block_size, mosaic_scale, block_size)
        tensor_blocks = tensor_blocks.permute(0, 2, 4, 1, 3, 5)  # (B, ms, ms, C, block_size, block_size)
        tensor_blocks = tensor_blocks.contiguous().view(B, K, C, block_size, block_size)  # (B, K, C, block_size, block_size)
        return tensor_blocks
    raise ValueError(f"Unknown method: {method}")
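
# A minimal sanity check added for illustration (not part of the original
# module); the tensor sizes below are arbitrary placeholders. It shows the
# difference between the two methods: grid 0 of 'tile_stride' is the top-left
# strided subsampling of x, whereas block 0 of 'tile_block' is the top-left tile.
def _demo_divide_tensor_spatial():
    x = torch.randn(2, 3, 512, 512)  # mosaic_scale = 512 // 256 = 2, so K = 4
    grids = divide_tensor_spatial(x, block_size=256, method='tile_stride')
    blocks = divide_tensor_spatial(x, block_size=256, method='tile_block')
    assert grids.shape == blocks.shape == (2, 4, 3, 256, 256)
    assert torch.equal(grids[:, 0], x[:, :, ::2, ::2])
    assert torch.equal(blocks[:, 0], x[:, :, :256, :256])
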
def merge_tensor_spatial(x, method='tile_stride'):
    K, N, feat_dim, Hm, Wm = x.shape
    mosaic_scale = int(math.sqrt(K))
    if method == 'tile_stride':
        # Inverse of the 'tile_stride' decomposition: fold the K stride-sampled
        # grids back into one (N, feat_dim, Hm * mosaic_scale, Wm * mosaic_scale) tensor.
        x = x.reshape(K, N, feat_dim, -1)
        fold_params_grid = dict(kernel_size=(mosaic_scale, mosaic_scale), stride=(mosaic_scale, mosaic_scale), padding=(0, 0), dilation=(1, 1))
        fold_grid = nn.Fold(output_size=(Hm * mosaic_scale, Wm * mosaic_scale), **fold_params_grid)
        x = x.permute(1, 2, 0, 3).reshape(N, feat_dim * K, -1)
        x = fold_grid(x)
        return x
    if method == 'tile_block':
        # Inverse of the 'tile_block' decomposition: place the K tiles back onto
        # the (mosaic_scale x mosaic_scale) grid.
        x = x.permute(1, 0, 2, 3, 4).reshape(N, mosaic_scale, mosaic_scale, feat_dim, Hm, Wm)
        x = x.permute(0, 3, 1, 4, 2, 5)
        x = x.reshape(N, feat_dim, mosaic_scale * Hm, mosaic_scale * Wm)
        return x
    raise ValueError(f"Unknown method: {method}")
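
# An illustrative round-trip check (added for exposition, not part of the
# original module). Note that divide_tensor_spatial returns (B, K, C, Hm, Wm)
# while merge_tensor_spatial expects (K, N, C, Hm, Wm), hence the permute.
def _demo_tile_round_trip():
    x = torch.randn(2, 3, 512, 512)
    for method in ('tile_stride', 'tile_block'):
        parts = divide_tensor_spatial(x, block_size=256, method=method)
        merged = merge_tensor_spatial(parts.permute(1, 0, 2, 3, 4), method=method)
        assert torch.equal(merged, x), method
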
def divide_overlapping_patches(input_tensor, patch_size, margin):
    B, C, W, _ = input_tensor.shape  # assumes a square input (H == W)
    stride = patch_size - margin
    # Smallest padded size at which the strided patches cover the whole image.
    padded_W = ((W - patch_size + stride - 1) // stride) * stride + patch_size
    pad = padded_W - W
    padded_tensor = F.pad(input_tensor, (0, pad, 0, pad), mode='constant', value=0)
    patches = F.unfold(padded_tensor, kernel_size=patch_size, stride=stride)  # (B, C * patch_size**2, L)
    patches = patches.view(B, C, patch_size, patch_size, -1).permute(0, 4, 1, 2, 3)  # (B, L, C, patch_size, patch_size)
    return patches
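
# An illustrative shape check (added for exposition; the sizes are arbitrary).
# With patch_size=256 and margin=32 the stride is 224, a 600 x 600 image is
# padded to 704 x 704, and a 3 x 3 grid of overlapping patches is produced.
def _demo_divide_overlapping_patches():
    x = torch.randn(1, 3, 600, 600)
    patches = divide_overlapping_patches(x, patch_size=256, margin=32)
    assert patches.shape == (1, 9, 3, 256, 256)
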
def merge_overlapping_patches(patches, patch_size, margin, original_size):
    B, _, C, _, _ = patches.shape
    stride = patch_size - margin
    W = original_size[2]
    # Recompute the padded size used by divide_overlapping_patches so that
    # F.fold receives a consistent number of blocks (folding directly to
    # (W, W) fails whenever the input needed padding).
    padded_W = ((W - patch_size + stride - 1) // stride) * stride + patch_size
    patches = patches.permute(0, 2, 3, 4, 1).contiguous().view(B, C * patch_size * patch_size, -1)
    output = F.fold(patches, (padded_W, padded_W), kernel_size=patch_size, stride=stride)
    # F.fold sums overlapping contributions; divide by the per-pixel patch
    # count to average them.
    weight = F.fold(torch.ones_like(patches), (padded_W, padded_W), kernel_size=patch_size, stride=stride)
    output = output / weight
    return output[:, :, :W, :W]  # crop away the zero padding
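
# An illustrative round-trip check (added for exposition, not part of the
# original module): because overlaps are averaged, dividing and merging
# should reproduce the input up to floating-point error.
def _demo_overlapping_round_trip():
    x = torch.randn(1, 3, 600, 600)
    patches = divide_overlapping_patches(x, patch_size=256, margin=32)
    merged = merge_overlapping_patches(patches, patch_size=256, margin=32, original_size=x.shape)
    assert torch.allclose(merged, x, atol=1e-6)


if __name__ == "__main__":
    # Run the illustrative sanity checks defined above.
    _demo_divide_tensor_spatial()
    _demo_tile_round_trip()
    _demo_divide_overlapping_patches()
    _demo_overlapping_round_trip()
    print("all sanity checks passed")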