import torch
import torch.nn.functional as F

class BilinearDownsampler(torch.nn.Module):
    """Downsamples feature maps by bilinear interpolation, one output per patch."""

    def __init__(
        self,
        patch_size,
    ):
        super().__init__()
        if isinstance(patch_size, int):
            self.patch_size = (patch_size, patch_size)
        elif isinstance(patch_size, tuple):
            self.patch_size = patch_size
        else:
            raise TypeError(f"patch_size must be int or tuple, got {type(patch_size)}")

    def forward(self, x, mode):
        # x: (n, v, h, w, 1, c); `mode` is unused but kept so the call site
        # matches PatchSalienceDownsampler.forward.
        n, v, h, w, _, c = x.shape
        assert h % self.patch_size[0] == 0
        target_h = h // self.patch_size[0]
        assert w % self.patch_size[1] == 0
        target_w = w // self.patch_size[1]
        # Merge batch, view, and singleton dims; F.interpolate expects (N, C, H, W).
        x = x.permute(0, 1, 4, 5, 2, 3).flatten(0, 2)
        x = F.interpolate(x, size=(target_h, target_w), mode="bilinear")
        # Restore the (n, v, target_h, target_w, 1, c) layout.
        x = x.reshape(n, v, -1, c, target_h, target_w).permute(0, 1, 4, 5, 2, 3)
        # Dims 2 and 3 are only dropped when target_h == target_w == 1.
        return x.squeeze((2, 3))

class PatchSalienceDownsampler(torch.nn.Module):
    """Reduces each patch to one feature via a learned, salience-weighted average."""

    def __init__(
        self,
        channels,
        patch_size,
        normalize_features,
    ):
        super().__init__()
        if isinstance(patch_size, int):
            self.patch_size = (patch_size, patch_size)
        elif isinstance(patch_size, tuple):
            self.patch_size = patch_size
        else:
            raise TypeError(f"patch_size must be int or tuple, got {type(patch_size)}")
        # 1x1 conv scores the salience of each pixel from its feature vector.
        self.conv = torch.nn.Conv2d(channels, 1, kernel_size=1)
        # Learned per-position affine transform applied to the salience scores.
        self.patch_weight = torch.nn.Parameter(torch.ones(self.patch_size))
        self.patch_bias = torch.nn.Parameter(torch.zeros(self.patch_size))
        self.normalize_features = normalize_features
        torch.nn.init.kaiming_normal_(self.conv.weight, a=0, mode="fan_in")
        torch.nn.init.zeros_(self.conv.bias)
        torch.nn.init.normal_(self.patch_weight, mean=1.0, std=0.01)
        torch.nn.init.normal_(self.patch_bias, mean=0.0, std=0.01)

    def forward(self, x, mode):
        if mode == "patch":
            # x: (n, p, patch_h, patch_w, 1, c), already split into patches.
            return self.forward_patches(x)
        elif mode == "image":
            # x: (n, v, h, w, 1, c); split it into non-overlapping patches first.
            n, v, h, w, _, c = x.shape
            patch_h, patch_w = self.patch_size[0], self.patch_size[1]
            no_patches_h, no_patches_w = h // patch_h, w // patch_w
            patches = x.reshape(n, v, no_patches_h, patch_h, no_patches_w, patch_w, 1, c)
            patches = patches.swapaxes(3, 4).flatten(1, 3)
            patched_result, salience_map, weight_map, patch_weight_bias = self.forward_patches(patches)
            # Fold the per-patch maps back into full-resolution image layout.
            patched_result = patched_result.reshape(n, v, no_patches_h, no_patches_w, 1, c)
            salience_map = salience_map.reshape(n, v, no_patches_h, no_patches_w, patch_h, patch_w, 1, 1)
            salience_map = salience_map.swapaxes(3, 4).reshape(n, v, h, w, 1, 1)
            weight_map = weight_map.reshape(n, v, no_patches_h, no_patches_w, patch_h, patch_w, 1, 1)
            weight_map = weight_map.swapaxes(3, 4).reshape(n, v, h, w, 1, 1)
            return patched_result, salience_map, weight_map, patch_weight_bias
        else:
            raise ValueError(f"unknown mode: {mode!r}")

    def forward_patches(self, x):
        n, p, patch_h, patch_w, _, c = x.shape
        x_flat = x.reshape(-1, patch_h, patch_w, c).permute(0, 3, 1, 2)
        # Per-pixel salience scores: (n * p, patch_h, patch_w).
        salience_map = self.conv(x_flat).squeeze(1)
        # Per-position affine, then softmax within each patch -> attention weights.
        weight_map = salience_map * self.patch_weight + self.patch_bias
        weight_map = F.softmax(weight_map.reshape(-1, patch_h * patch_w), dim=1)
        weight_map = weight_map.reshape(n, p, patch_h, patch_w, 1, 1)
        # Salience-weighted average over each patch: (n, p, 1, c).
        patched_features = torch.sum(weight_map * x, dim=(2, 3))
        if self.normalize_features:
            # L2-normalize the pooled features along the channel dim.
            patched_features = patched_features / torch.linalg.norm(patched_features, dim=-1, keepdim=True)
        return (patched_features,
                salience_map.reshape(n, p, patch_h, patch_w, 1, 1),
                weight_map,
                torch.cat([self.patch_weight, self.patch_bias], dim=1))
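
# Minimal smoke test: a sketch, not part of the original file. The input shape
# (n, v, h, w, 1, c) is an assumption inferred from the forward() signatures.
if __name__ == "__main__":
    n, v, h, w, c = 2, 3, 8, 8, 16  # batch, views, height, width, channels
    x = torch.randn(n, v, h, w, 1, c)

    bilinear = BilinearDownsampler(patch_size=4)
    print(bilinear(x, mode="image").shape)  # torch.Size([2, 3, 2, 2, 1, 16])

    salience = PatchSalienceDownsampler(channels=c, patch_size=4, normalize_features=True)
    feats, salience_map, weight_map, patch_wb = salience(x, mode="image")
    print(feats.shape)         # torch.Size([2, 3, 2, 2, 1, 16])
    print(salience_map.shape)  # torch.Size([2, 3, 8, 8, 1, 1])
    print(weight_map.shape)    # torch.Size([2, 3, 8, 8, 1, 1])
    print(patch_wb.shape)      # torch.Size([4, 8])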