# CrossFlow: libs/model/common_layers.py
# Last commit f9567e5 by QHL067 ("working")
import torch
import torch.nn.functional as F
from torch import nn
from timm.models.layers import trunc_normal_
class Linear(nn.Linear):
    """``nn.Linear`` re-initialised with a truncated-normal weight (mean 0, std 0.02) and a zero bias.

    All constructor arguments are forwarded unchanged to ``nn.Linear``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Replace the parent's default initialisation drawn in super().__init__.
        trunc_normal_(self.weight, mean=0, std=0.02)
        bias = self.bias
        if bias is None:
            # Constructed with bias=False: nothing else to initialise.
            return
        nn.init.zeros_(bias)
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` re-initialised with a truncated-normal scale (mean 0, std 0.02) and a zero shift.

    The affine parameters are only touched when they exist. With
    ``elementwise_affine=False`` (or ``bias=False`` on PyTorch versions that
    accept it) the parent registers them as ``None``. Unconditionally
    re-initialising ``weight`` would then raise instead of building a plain
    non-affine LayerNorm.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if self.weight is not None:
            # NOTE(review): nn.LayerNorm initialises its scale to ones; drawing it
            # from N(0, 0.02) instead makes normalised outputs start near zero with
            # random sign. Kept to preserve existing training behaviour, but confirm
            # this is intended.
            trunc_normal_(self.weight, mean=0, std=0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel starts from a truncated normal (mean 0, std 0.02) and whose bias starts at zero.

    All constructor arguments are forwarded unchanged to ``nn.Conv2d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        kernel, bias = self.weight, self.bias
        # Overwrite the default initialisation performed by the parent constructor.
        trunc_normal_(kernel, mean=0, std=0.02)
        if bias is not None:
            nn.init.zeros_(bias)
class Embedding(nn.Embedding):
    """``nn.Embedding`` whose table is re-initialised from a truncated normal (mean 0, std 0.02).

    Two properties of the parent are preserved:

    * A caller-supplied table (``_weight=...``) is kept as-is. This is how
      ``nn.Embedding.from_pretrained`` builds the module. The parameter shares
      storage with that tensor, so re-initialising it would discard the
      pretrained values and overwrite the caller's tensor in place.
    * The ``padding_idx`` row is zero at construction, as ``nn.Embedding``
      guarantees for its own default init. That row receives no gradient, so a
      random value drawn here would otherwise persist.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Only the keyword form of the private `_weight` argument (the one
        # from_pretrained uses) is detected here.
        if kwargs.get("_weight") is not None:
            return
        trunc_normal_(self.weight, mean=0, std=0.02)
        if self.padding_idx is not None:
            # nn.Embedding has already normalised a negative padding_idx to a row index.
            with torch.no_grad():
                self.weight[self.padding_idx].fill_(0)
class ImageNorm(nn.Module):
    """Divide each sample of a 4-D batch by its standard deviation over dims 1, 2 and 3.

    The variance is PyTorch's default (unbiased) ``Tensor.var``, stabilised by
    ``1e-05``. The mean is not subtracted, so this is a pure per-sample rescale.
    """

    def forward(self, x):
        assert x.dim() == 4
        per_sample_var = x.var(dim=(1, 2, 3), keepdim=True)
        denom = (per_sample_var + 1e-05).sqrt()
        return x / denom
class Flatten(nn.Module):
    """Merge the two middle axes of a ``[B, H, W, C]`` tensor into one, giving ``[B, H * W, C]``."""

    def forward(self, x):
        # Unpacking doubles as a rank check: a non-4-D input raises ValueError here.
        batch, height, width, channels = x.shape
        tokens = height * width
        return x.reshape(batch, tokens, channels)
class ChannelLast(nn.Module):
    """Reorder a 4-D tensor from ``[B, C, H, W]`` to ``[B, H, W, C]`` (returns a view)."""

    def forward(self, x):
        assert x.dim() == 4
        # Moving axis 1 to the end is the permutation (0, 2, 3, 1).
        return x.movedim(1, -1)
class ChannelFirst(nn.Module):
    """Reorder a 4-D tensor from ``[B, H, W, C]`` to ``[B, C, H, W]`` (returns a view)."""

    def forward(self, x):
        assert x.dim() == 4
        # Moving the last axis to position 1 is the permutation (0, 3, 1, 2).
        return x.movedim(-1, 1)
class OddUpInterpolate(nn.Module):
    """Bilinearly upsample a 4-D ``[B, C, H, W]`` tensor so each spatial size ``n`` becomes ``(n - 1) * ratio + 1``.

    ``align_corners=True`` makes the first and last samples of input and output
    coincide, so the input grid points land exactly on every ``ratio``-th output
    position.

    Args:
        ratio: upsampling factor; ``1`` returns the input object unchanged.
    """

    def __init__(self, ratio):
        super().__init__()
        self.ratio = ratio

    def forward(self, x):
        if self.ratio == 1:
            return x
        assert x.dim() == 4
        _, _, height, width = x.shape
        size = ((height - 1) * self.ratio + 1, (width - 1) * self.ratio + 1)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=True)

    def extra_repr(self):
        # nn.Module.__repr__ prefixes the real class name, e.g. "OddUpInterpolate(ratio=2)";
        # the former hand-written __repr__ printed "UpInterpolate(...)" instead.
        return f"ratio={self.ratio}"
class OddDownInterpolate(nn.Module):
    """Area-downsample a 4-D ``[B, C, H, W]`` tensor so each spatial size ``n`` becomes ``(n - 1) // ratio + 1``.

    For sizes of the form ``k * ratio + 1`` this inverts the size mapping of
    ``OddUpInterpolate``. ``mode="area"`` averages over input windows
    (adaptive average pooling).

    Args:
        ratio: downsampling factor; ``1`` returns the input object unchanged.
    """

    def __init__(self, ratio):
        super().__init__()
        self.ratio = ratio

    def forward(self, x):
        if self.ratio == 1:
            return x
        assert x.dim() == 4
        _, _, height, width = x.shape
        size = ((height - 1) // self.ratio + 1, (width - 1) // self.ratio + 1)
        return F.interpolate(x, size=size, mode="area")

    def extra_repr(self):
        # nn.Module.__repr__ prefixes the real class name, so this no longer prints
        # the same "DownInterpolate(...)" as EvenDownInterpolate.
        return f"ratio={self.ratio}"
class EvenDownInterpolate(nn.Module):
    """Area-downsample a 4-D ``[B, C, H, W]`` tensor so each spatial size ``n`` becomes ``n // ratio``.

    When ``n`` is divisible by ``ratio`` this equals non-overlapping average
    pooling with a ``ratio x ratio`` window.

    Args:
        ratio: downsampling factor; ``1`` returns the input object unchanged.
    """

    def __init__(self, ratio):
        super().__init__()
        self.ratio = ratio

    def forward(self, x):
        if self.ratio == 1:
            return x
        assert x.dim() == 4
        _, _, height, width = x.shape
        size = (height // self.ratio, width // self.ratio)
        return F.interpolate(x, size=size, mode="area")

    def extra_repr(self):
        # nn.Module.__repr__ prefixes the real class name, so this no longer prints
        # the same "DownInterpolate(...)" as OddDownInterpolate.
        return f"ratio={self.ratio}"