# sdutta28's picture
# HF Changes
# 32cc554
import torch
from torch import nn
class ConvADN(nn.Module):
    """Convolution followed by Activation (GELU), Dropout and instance Normalization.

    The convolution is ``nn.Conv2d`` by default, or ``nn.ConvTranspose2d`` when
    ``is_transpose`` is true. The block therefore works on 2-D feature maps
    laid out as ``(N, C, H, W)``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels. It is also the feature count
            passed to the normalization layer.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
        padding: Convolution padding.
        p_drop: Dropout probability, applied after the activation.
        is_transpose: Use a transposed (upsampling) convolution instead of a
            regular one.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel=2,
        stride=2,
        dilation=1,
        padding=0,
        p_drop=0.2,
        is_transpose: bool = False,
    ):
        super().__init__()
        conv_cls = nn.ConvTranspose2d if is_transpose else nn.Conv2d
        self.model = nn.Sequential(
            conv_cls(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel,
                stride=stride,
                dilation=dilation,
                padding=padding,
            ),
            nn.GELU(),
            nn.Dropout(p_drop),
            # The conv output is 2-D spatial, so normalise with InstanceNorm2d:
            # each (sample, channel) map is normalised over H and W.
            # InstanceNorm3d would read a batched 4-D tensor as an unbatched
            # (C, D, H, W) input and mistake the batch axis for channels.
            # With the defaults (affine=False, no running stats) this layer has
            # no parameters, so existing checkpoints still load.
            nn.InstanceNorm2d(num_features=out_channels),
        )

    def forward(self, x):
        """Apply conv -> GELU -> dropout -> instance norm to ``x``."""
        return self.model(x)
class Encoder(nn.Module):
    """Downsampling feature extractor made of four stride-2 ``ConvADN`` stages.

    The channel width goes ``in_channels -> 32 -> 64 -> 128 -> 256``. Each
    stage uses kernel 2, stride 2 and no padding, so it halves the spatial size
    (rounding down).

    Args:
        in_channels: Number of channels in the input image.
    """

    def __init__(self, in_channels: int = 3):
        super().__init__()
        widths = (in_channels, 32, 64, 128, 256)
        # Consecutive (in, out) width pairs; the stages are built in order.
        stages = [
            ConvADN(c_in, c_out, kernel=2, stride=2, padding=0)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into a 256-channel feature map."""
        encoded = self.model(x)
        return encoded
class Decoder(nn.Module):
    """Upsampling decoder: four transposed ``ConvADN`` stages plus a 1x1 conv.

    The channel width goes ``256 -> 128 -> 64 -> 32 -> out_channels``. Each
    transposed stage uses kernel 2, stride 2 and no padding, so it doubles the
    spatial size. The last stage still applies GELU, dropout and normalization
    (via ``ConvADN``). The 1x1 convolution then adds a per-pixel affine
    projection at the output.

    Args:
        out_channels: Number of channels in the reconstructed output.
    """

    def __init__(self, out_channels: int = 3):
        super().__init__()
        widths = (256, 128, 64, 32, out_channels)
        self.model = nn.Sequential(
            *(
                ConvADN(c_in, c_out, kernel=2, stride=2, padding=0, is_transpose=True)
                for c_in, c_out in zip(widths, widths[1:])
            )
        )
        self.output = nn.Conv2d(
            out_channels, out_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a 256-channel feature map into ``out_channels`` channels."""
        upsampled = self.model(x)
        projected = self.output(upsampled)
        return projected
class Autoencoder(nn.Module):
    """Convolutional autoencoder: ``Encoder`` followed by ``Decoder``.

    The encoder has four stride-2 stages and the decoder has four
    stride-2 transposed stages. The output spatial size is therefore
    ``16 * floor(size / 16)``. It matches the input only when H and W are
    multiples of 16.

    Args:
        in_channels: Channels of the input image fed to the encoder.
        out_channels: Channels of the reconstruction produced by the decoder.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ) -> None:
        super().__init__()
        self.encoder = Encoder(in_channels)
        self.decoder = Decoder(out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` to a latent feature map and decode it back."""
        latent = self.encoder(x)
        return self.decoder(latent)