"""Small convolutional building blocks wrapping ``torch.nn`` layers."""

import torch
from torch import nn
from torch.nn import functional as F


class Conv2d(nn.Module):
    """Conv2d -> BatchNorm2d -> (optional residual add) -> ReLU.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.
        kernel_size: Passed positionally to ``nn.Conv2d``.
        stride: Passed positionally to ``nn.Conv2d``.
        padding: Passed positionally to ``nn.Conv2d``.
        residual: When true, the block input is added in place to the
            normalized convolution output before the activation. The input
            must therefore be broadcastable to the output's shape.
        *args: Forwarded to ``nn.Module.__init__``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding,
                 residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` and return the activated result."""
        out = self.conv_block(x)
        if self.residual:
            # In-place add on the freshly computed output; ``x`` is untouched.
            out += x
        return self.act(out)


class nonorm_Conv2d(nn.Module):
    """Conv2d -> LeakyReLU(0.01), with no normalization layer.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.
        kernel_size: Passed positionally to ``nn.Conv2d``.
        stride: Passed positionally to ``nn.Conv2d``.
        padding: Passed positionally to ``nn.Conv2d``.
        residual: Accepted but unused; this block never adds a skip
            connection and does not store the flag.
        *args: Forwarded to ``nn.Module.__init__``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding,
                 residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
        )
        # The activation overwrites the convolution output in place.
        self.act = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution followed by the leaky ReLU."""
        out = self.conv_block(x)
        return self.act(out)


class Conv2dTranspose(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> ReLU.

    Args:
        cin: Number of input channels.
        cout: Number of output channels.
        kernel_size: Passed positionally to ``nn.ConvTranspose2d``.
        stride: Passed positionally to ``nn.ConvTranspose2d``.
        padding: Passed positionally to ``nn.ConvTranspose2d``.
        output_padding: Passed positionally to ``nn.ConvTranspose2d``.
        *args: Forwarded to ``nn.Module.__init__``.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding,
                 output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(
                cin, cout, kernel_size, stride, padding, output_padding
            ),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the transposed convolution block to ``x``."""
        out = self.conv_block(x)
        return self.act(out)