File size: 3,045 Bytes
f9567e5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import trunc_normal_
class Linear(nn.Linear):
    """``nn.Linear`` with a custom parameter initialisation.

    Accepts exactly the constructor arguments of ``nn.Linear``. After the
    parent constructor has run, the weight is re-drawn with
    ``trunc_normal_`` (mean 0, std 0.02) and the bias, when the layer has
    one, is set to zero.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Replace the parent's default weight initialisation.
        trunc_normal_(self.weight, mean=0, std=0.02)

        # A layer built with ``bias=False`` has nothing more to initialise.
        bias = self.bias
        if bias is None:
            return
        nn.init.zeros_(bias)
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` with a custom parameter initialisation.

    Accepts exactly the constructor arguments of ``nn.LayerNorm``. After the
    parent constructor has run, the weight (gain) is re-drawn with
    ``trunc_normal_`` (mean 0, std 0.02) and the bias is set to zero.

    Both parameters are optional: ``elementwise_affine=False`` leaves the
    weight and the bias as ``None``, in which case the corresponding
    initialisation step is skipped instead of failing.

    NOTE(review): the gain starts near zero rather than at one, which is
    unusual for layer normalisation -- presumably intentional; confirm.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # The weight is None when the layer has no affine parameters, so it
        # needs the same guard as the bias.
        if self.weight is not None:
            trunc_normal_(self.weight, mean=0, std=0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` with a custom parameter initialisation.

    Accepts exactly the constructor arguments of ``nn.Conv2d``. Once the
    parent constructor has finished, the kernel weights are re-drawn with
    ``trunc_normal_`` (mean 0, std 0.02); the bias, if the layer has one,
    is zeroed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        kernel, offset = self.weight, self.bias

        # Replace the parent's default kernel initialisation.
        trunc_normal_(kernel, mean=0, std=0.02)

        # ``bias=False`` leaves the bias as None; only zero a real tensor.
        if offset is not None:
            nn.init.zeros_(offset)
class Embedding(nn.Embedding):
    """``nn.Embedding`` with a custom weight initialisation.

    Accepts exactly the constructor arguments of ``nn.Embedding``. A freshly
    created table is re-drawn with ``trunc_normal_`` (mean 0, std 0.02).

    Two cases are handled so the re-initialisation does not break the
    behaviour of the parent class:

    * When explicit weights are supplied through the ``_weight`` keyword
      (this is how ``nn.Embedding.from_pretrained`` builds the module), the
      supplied values are kept untouched rather than overwritten in place.
    * When ``padding_idx`` is set, the padding row is zeroed again after the
      random draw, matching the all-zero padding vector the parent class
      sets up for a newly created table.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Caller-provided weights (e.g. pretrained vectors) must survive
        # construction; re-drawing them would silently destroy the data.
        if kwargs.get("_weight") is not None:
            return

        trunc_normal_(self.weight, mean=0, std=0.02)

        # The random draw above also overwrote the padding row; restore it.
        if self.padding_idx is not None:
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)
class ImageNorm(nn.Module):
    """Rescale every sample of a 4-D batch by its own standard deviation.

    The variance is taken jointly over dimensions 1, 2 and 3 of each sample
    and a small constant is added before the square root for numerical
    stability. The mean is not subtracted and the module has no learnable
    parameters.
    """

    def forward(self, x):
        """Return ``x`` divided by its per-sample standard deviation.

        Raises ``AssertionError`` if ``x`` is not 4-dimensional.
        """
        assert x.dim() == 4
        eps = 1e-05
        # One variance per sample, kept broadcastable against ``x``.
        variance = torch.var(x, dim=(1, 2, 3), keepdim=True)
        scale = torch.sqrt(variance + eps)
        return x / scale
class Flatten(nn.Module):
    """Merge the two middle dimensions of a 4-D tensor into one.

    The input is unpacked as ``[B, H, W, C]`` and returned with shape
    ``[B, H * W, C]``.
    """

    def forward(self, x):
        """Return ``x`` reshaped from ``[B, H, W, C]`` to ``[B, H * W, C]``.

        Unpacking the shape fails with ``ValueError`` unless ``x`` has
        exactly four dimensions.
        """
        batch, height, width, channels = x.shape
        return x.reshape(batch, height * width, channels)
class ChannelLast(nn.Module):
    """Move dimension 1 of a 4-D tensor to the end.

    A ``[B, C, H, W]`` input comes back as ``[B, H, W, C]``.
    """

    def forward(self, x):
        """Return ``x`` permuted to ``[B, H, W, C]``.

        Raises ``AssertionError`` if ``x`` is not 4-dimensional.
        """
        assert x.dim() == 4
        order = (0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        return x.permute(*order)
class ChannelFirst(nn.Module):
    """Move the last dimension of a 4-D tensor to position 1.

    A ``[B, H, W, C]`` input comes back as ``[B, C, H, W]``.
    """

    def forward(self, x):
        """Return ``x`` permuted to ``[B, C, H, W]``.

        Raises ``AssertionError`` if ``x`` is not 4-dimensional.
        """
        assert x.dim() == 4
        order = (0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        return x.permute(*order)
class OddUpInterpolate(nn.Module):
    """Bilinear upsampling of a ``[B, C, H, W]`` tensor by an integer ratio.

    The output spatial size is ``((H - 1) * ratio + 1, (W - 1) * ratio + 1)``
    and interpolation uses ``align_corners=True``, so every ``ratio``-th
    output sample coincides with an input sample.

    Args:
        ratio: Upsampling factor. A ratio of 1 turns the module into a
            no-op that returns its input unchanged.
    """

    def __init__(self, ratio):
        super().__init__()
        self.ratio = ratio

    def forward(self, x):
        """Return ``x`` upsampled by ``self.ratio``.

        Raises ``AssertionError`` if ``ratio != 1`` and ``x`` is not
        4-dimensional.
        """
        if self.ratio == 1:
            return x
        assert x.dim() == 4
        height, width = x.shape[-2:]
        target = (
            (height - 1) * self.ratio + 1,
            (width - 1) * self.ratio + 1,
        )
        return F.interpolate(
            x, size=target, mode="bilinear", align_corners=True
        )

    def __repr__(self):
        # Use the real class name so the printed model identifies this
        # module unambiguously.
        return f"{type(self).__name__}(ratio={self.ratio})"
class OddDownInterpolate(nn.Module):
    """Area downsampling of a ``[B, C, H, W]`` tensor by an integer ratio.

    The output spatial size is
    ``((H - 1) // ratio + 1, (W - 1) // ratio + 1)`` and resampling uses
    ``mode="area"``.

    Args:
        ratio: Downsampling factor. A ratio of 1 turns the module into a
            no-op that returns its input unchanged.
    """

    def __init__(self, ratio):
        super().__init__()
        self.ratio = ratio

    def forward(self, x):
        """Return ``x`` downsampled by ``self.ratio``.

        Raises ``AssertionError`` if ``ratio != 1`` and ``x`` is not
        4-dimensional.
        """
        if self.ratio == 1:
            return x
        assert x.dim() == 4
        height, width = x.shape[-2:]
        target = (
            (height - 1) // self.ratio + 1,
            (width - 1) // self.ratio + 1,
        )
        return F.interpolate(x, size=target, mode="area")

    def __repr__(self):
        # Use the real class name; a shared "DownInterpolate" label would
        # make this module indistinguishable from EvenDownInterpolate in a
        # printed model.
        return f"{type(self).__name__}(ratio={self.ratio})"
class EvenDownInterpolate(nn.Module):
    """Area downsampling of a ``[B, C, H, W]`` tensor by an integer ratio.

    The output spatial size is ``(H // ratio, W // ratio)`` and resampling
    uses ``mode="area"``.

    Args:
        ratio: Downsampling factor. A ratio of 1 turns the module into a
            no-op that returns its input unchanged.
    """

    def __init__(self, ratio):
        super().__init__()
        self.ratio = ratio

    def forward(self, x):
        """Return ``x`` downsampled by ``self.ratio``.

        Raises ``AssertionError`` if ``ratio != 1`` and ``x`` is not
        4-dimensional.
        """
        if self.ratio == 1:
            return x
        assert x.dim() == 4
        height, width = x.shape[-2:]
        target = (height // self.ratio, width // self.ratio)
        return F.interpolate(x, size=target, mode="area")

    def __repr__(self):
        # Use the real class name; a shared "DownInterpolate" label would
        # make this module indistinguishable from OddDownInterpolate in a
        # printed model.
        return f"{type(self).__name__}(ratio={self.ratio})"