from torch import nn
import pdb
import torch


class _DiscriminatorBase(nn.Module):
    """Convolutional discriminator shared by ``Discriminator64`` and ``Discriminator``.

    Architecture: four 4x4 stride-2 convolutions (4 -> 64 -> 128 -> 256 -> 512
    channels), with optional BatchNorm after layers 2-4 and a LeakyReLU after
    each, followed by a final stride-1 convolution of kernel ``final_kernel``
    and a sigmoid.

    Args:
        final_kernel: Kernel size of the last convolution (``layer5``).
        bnorm: Apply BatchNorm after layers 2-4.
        leakyparam: Negative slope of the LeakyReLU (0.0 behaves like ReLU).
        bias: Whether the convolutions have a bias term.
        generic: If True, ``layer5`` emits one score channel per letter and
            ``forward`` selects the channel given by ``letter``; otherwise it
            emits a single score channel.
        num_letters: Number of output channels in generic mode (default 26).
    """

    def __init__(self, final_kernel, bnorm=True, leakyparam=0.0, bias=False,
                 generic=False, num_letters=26):
        super().__init__()
        self.bnorm = bnorm
        self.generic = generic
        self.relu = nn.LeakyReLU(leakyparam, inplace=True)
        self.sig = nn.Sigmoid()
        # Created even when bnorm=False so parameter names / state_dict keys
        # do not depend on the flag.
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)
        self.layer1 = nn.Conv2d(4, 64, 4, 2, 1, bias=bias)
        self.layer2 = nn.Conv2d(64, 128, 4, 2, 1, bias=bias)
        self.layer3 = nn.Conv2d(128, 256, 4, 2, 1, bias=bias)
        self.layer4 = nn.Conv2d(256, 512, 4, 2, 1, bias=bias)
        out_channels = num_letters if generic else 1
        self.layer5 = nn.Conv2d(512, out_channels, final_kernel, 1, 0, bias=bias)

    def forward(self, input, letter):
        """Return the mean sigmoid score as a 0-d tensor.

        Args:
            input: Batch of 4-channel images, shape ``(B, 4, H, W)``.
            letter: Channel index used when ``generic`` is True; ignored
                otherwise. Any index accepted by tensor indexing along the
                channel dimension works (int, list, or index tensor).

        Returns:
            In generic mode, the mean of channel(s) ``letter`` over the whole
            batch and spatial map; otherwise the mean of all scores.
        """
        out = self.relu(self.layer1(input))
        for conv, bn in ((self.layer2, self.bn2),
                         (self.layer3, self.bn3),
                         (self.layer4, self.bn4)):
            out = conv(out)
            if self.bnorm:
                out = bn(out)
            out = self.relu(out)
        scores = self.sig(self.layer5(out))  # (B, C, H', W')
        if self.generic:
            # Index the channel dimension, not the flattened tensor: flattening
            # first would only ever read batch item 0 when B > 1.
            return scores[:, letter].mean()
        return scores.mean()


# to use with clip
class Discriminator64(_DiscriminatorBase):
    """Discriminator with a 4x4 final convolution.

    Four stride-2 convolutions reduce a 64x64 input to 4x4, and the 4x4 final
    convolution then yields a 1x1 score map.
    """

    def __init__(self, bnorm=True, leakyparam=0.0, bias=False, generic=False,
                 num_letters=26):
        super().__init__(4, bnorm=bnorm, leakyparam=leakyparam, bias=bias,
                         generic=generic, num_letters=num_letters)


# to use with bert
class Discriminator(_DiscriminatorBase):
    """Discriminator with a 2x2 final convolution.

    Four stride-2 convolutions reduce a 32x32 input to 2x2, and the 2x2 final
    convolution then yields a 1x1 score map.
    """

    def __init__(self, bnorm=True, leakyparam=0.0, bias=False, generic=False,
                 num_letters=26):
        super().__init__(2, bnorm=bnorm, leakyparam=leakyparam, bias=bias,
                         generic=generic, num_letters=num_letters)