Spaces:
Runtime error
Runtime error
from torch import nn | |
import pdb | |
import torch | |
# To use with CLIP.
class Discriminator64(nn.Module):
    """Convolutional discriminator for 4-channel images, used with CLIP.

    Four stride-2 4x4 convolutions halve the spatial size each time
    (64 -> 32 -> 16 -> 8 -> 4 for a 64x64 input). A final 4x4 valid
    convolution then reduces a 4x4 map to 1x1. The sigmoid scores of that
    head are averaged into a single scalar.

    Args:
        bnorm: apply BatchNorm after convolutions 2-4.
        leakyparam: negative slope of the LeakyReLU (0.0 is a plain ReLU).
        bias: whether the convolutions learn a bias term.
        generic: if True the head has 26 output channels (presumably one
            per letter -- confirm with callers) and ``forward`` scores only
            the requested channel(s); otherwise the head has one channel.
    """

    def __init__(self, bnorm=True, leakyparam=0.0, bias=False, generic=False):
        super().__init__()
        self.bnorm = bnorm
        self.generic = generic
        # Keep this registration order: it fixes parameters() order (which
        # optimizer state relies on) and the RNG sequence of seeded init.
        self.relu = nn.LeakyReLU(leakyparam, inplace=True)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)
        self.layer1 = nn.Conv2d(4, 64, 4, 2, 1, bias=bias)
        self.layer2 = nn.Conv2d(64, 128, 4, 2, 1, bias=bias)
        self.layer3 = nn.Conv2d(128, 256, 4, 2, 1, bias=bias)
        self.layer4 = nn.Conv2d(256, 512, 4, 2, 1, bias=bias)
        out_channels = 26 if generic else 1
        self.layer5 = nn.Conv2d(512, out_channels, 4, 1, 0, bias=bias)
        self.sig = nn.Sigmoid()

    def forward(self, input, letter):
        """Score a batch and return the mean sigmoid output as a 0-dim tensor.

        Args:
            input: image batch of shape (batch, 4, H, W). An unbatched
                (4, H, W) tensor is treated as a batch of one; that only
                works with ``bnorm=False``, since BatchNorm2d needs 4-D input.
            letter: index (int, or index sequence/tensor) of the head
                channel(s) to score when ``generic`` is True; ignored
                otherwise. The same channel(s) are scored for every sample.

        Returns:
            Mean of the selected sigmoid scores over the batch and all
            spatial positions.
        """
        out = self.relu(self.layer1(input))
        for conv, bn in ((self.layer2, self.bn2),
                         (self.layer3, self.bn3),
                         (self.layer4, self.bn4)):
            out = conv(out)
            if self.bnorm:
                out = bn(out)
            out = self.relu(out)
        scores = self.sig(self.layer5(out))
        if self.generic:
            if scores.dim() == 3:  # unbatched (C, H, W) input
                scores = scores.unsqueeze(0)
            # Select on the channel axis explicitly. Indexing a fully
            # flattened (batch*channel*H*W) vector would read only sample 0
            # and would mix channels with spatial positions for maps > 1x1.
            scores = scores.flatten(2)[:, letter]
        return scores.mean()
# To use with BERT.
class Discriminator(nn.Module):
    """Convolutional discriminator for 4-channel images, used with BERT.

    Four stride-2 4x4 convolutions halve the spatial size each time. A
    final 2x2 valid convolution forms the head:
    - a 32x32 input (2x2 after layer4) yields a 1x1 head map;
    - a 64x64 input (4x4 after layer4) yields a 3x3 head map.
    The sigmoid scores of the head are averaged into a single scalar.

    Args:
        bnorm: apply BatchNorm after convolutions 2-4.
        leakyparam: negative slope of the LeakyReLU (0.0 is a plain ReLU).
        bias: whether the convolutions learn a bias term.
        generic: if True the head has 26 output channels (presumably one
            per letter -- confirm with callers) and ``forward`` scores only
            the requested channel(s); otherwise the head has one channel.
    """

    def __init__(self, bnorm=True, leakyparam=0.0, bias=False, generic=False):
        super().__init__()
        self.bnorm = bnorm
        self.generic = generic
        # Keep this registration order: it fixes parameters() order (which
        # optimizer state relies on) and the RNG sequence of seeded init.
        self.relu = nn.LeakyReLU(leakyparam, inplace=True)
        self.sig = nn.Sigmoid()
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)
        self.layer1 = nn.Conv2d(4, 64, 4, 2, 1, bias=bias)
        self.layer2 = nn.Conv2d(64, 128, 4, 2, 1, bias=bias)
        self.layer3 = nn.Conv2d(128, 256, 4, 2, 1, bias=bias)
        self.layer4 = nn.Conv2d(256, 512, 4, 2, 1, bias=bias)
        out_channels = 26 if generic else 1
        self.layer5 = nn.Conv2d(512, out_channels, 2, 1, 0, bias=bias)

    def forward(self, input, letter):
        """Score a batch and return the mean sigmoid output as a 0-dim tensor.

        Args:
            input: image batch of shape (batch, 4, H, W). An unbatched
                (4, H, W) tensor is treated as a batch of one; that only
                works with ``bnorm=False``, since BatchNorm2d needs 4-D input.
            letter: index (int, or index sequence/tensor) of the head
                channel(s) to score when ``generic`` is True; ignored
                otherwise. The same channel(s) are scored for every sample.

        Returns:
            Mean of the selected sigmoid scores over the batch and all
            spatial positions.
        """
        out = self.relu(self.layer1(input))
        for conv, bn in ((self.layer2, self.bn2),
                         (self.layer3, self.bn3),
                         (self.layer4, self.bn4)):
            out = conv(out)
            if self.bnorm:
                out = bn(out)
            out = self.relu(out)
        scores = self.sig(self.layer5(out))
        if self.generic:
            if scores.dim() == 3:  # unbatched (C, H, W) input
                scores = scores.unsqueeze(0)
            # Select on the channel axis explicitly. Indexing a fully
            # flattened (batch*channel*H*W) vector would read only sample 0
            # and would mix channels with spatial positions for maps > 1x1.
            scores = scores.flatten(2)[:, letter]
        return scores.mean()