# DS-Fusion-Express/ldm/modules/discriminator.py
# Last commit: ca4133a ("first") by mta122
from torch import nn
import pdb
import torch
# to use with clip
class Discriminator64(nn.Module):
    """Convolutional discriminator whose last conv collapses a 64x64 input to 1x1.

    Four 4x4 stride-2 convolutions (padding 1) halve the spatial size each
    time (64 -> 32 -> 16 -> 8 -> 4). ``layer5`` is a 4x4 convolution with no
    padding, so a 64x64 input yields one sigmoid score per output channel.

    Args:
        bnorm: apply BatchNorm after layers 2-4.
        leakyparam: negative slope of the LeakyReLU (0.0 makes it a plain ReLU).
        bias: whether the convolutions carry a bias term.
        generic: if True, ``layer5`` emits ``num_classes`` scores and
            ``forward`` selects the score(s) for ``letter``. Otherwise it emits
            a single score.
        in_channels: channel count of the input tensor (default 4).
        num_classes: number of per-class scores when ``generic`` (default 26).
    """

    def __init__(self, bnorm=True, leakyparam=0.0, bias=False, generic=False,
                 in_channels=4, num_classes=26):
        super().__init__()
        self.bnorm = bnorm
        self.generic = generic
        # In-place is safe here: the activation is only applied to freshly
        # computed conv / BatchNorm outputs, never to the caller's tensor.
        self.relu = nn.LeakyReLU(leakyparam, inplace=True)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)
        self.layer1 = nn.Conv2d(in_channels, 64, 4, 2, 1, bias=bias)
        self.layer2 = nn.Conv2d(64, 128, 4, 2, 1, bias=bias)
        self.layer3 = nn.Conv2d(128, 256, 4, 2, 1, bias=bias)
        self.layer4 = nn.Conv2d(256, 512, 4, 2, 1, bias=bias)
        out_channels = num_classes if generic else 1
        self.layer5 = nn.Conv2d(512, out_channels, 4, 1, 0, bias=bias)
        self.sig = nn.Sigmoid()

    def forward(self, input, letter):
        """Return the mean sigmoid score as a 0-dim tensor.

        Args:
            input: batch assumed to be ``(B, in_channels, H, W)``; H = W = 64
                gives a 1x1 score map.
            letter: ignored unless ``generic``. Accepted forms:
                - a single class index, scored for every sample;
                - a 1-D sequence/tensor of length ``B``, one index per sample;
                - any other index sequence, applied to every sample.

        Returns:
            Scalar tensor: the mean of the selected sigmoid scores (all scores
            when not ``generic``).
        """
        out = self.relu(self.layer1(input))
        stages = (
            (self.layer2, self.bn2),
            (self.layer3, self.bn3),
            (self.layer4, self.bn4),
        )
        for conv, bn in stages:
            out = conv(out)
            if self.bnorm:
                out = bn(out)
            out = self.relu(out)
        scores = self.sig(self.layer5(out))
        if not self.generic:
            return scores.mean()
        return self._letter_scores(scores, letter).mean()

    @staticmethod
    def _letter_scores(scores, letter):
        """Select the class score(s) named by ``letter`` from a (B, C, H, W) map."""
        # Keep the batch and class axes separate. Indexing a fully flattened
        # (B*C*H*W) vector would make class indices address sample 0 only.
        scores = scores.reshape(scores.shape[0], scores.shape[1], -1)
        batch = scores.shape[0]
        idx = torch.as_tensor(letter, device=scores.device).long()
        if idx.dim() == 1 and idx.numel() == batch:
            # One class index per sample.
            rows = torch.arange(batch, device=scores.device)
            return scores[rows, idx]
        # Scalar (or other) index: same class(es) for every sample.
        return scores[:, idx]
# to use with bert
class Discriminator(nn.Module):
    """Convolutional discriminator whose last conv collapses a 32x32 input to 1x1.

    Four 4x4 stride-2 convolutions (padding 1) halve the spatial size each
    time (32 -> 16 -> 8 -> 4 -> 2). ``layer5`` is a 2x2 convolution with no
    padding, so a 32x32 input yields one sigmoid score per output channel.

    Args:
        bnorm: apply BatchNorm after layers 2-4.
        leakyparam: negative slope of the LeakyReLU (0.0 makes it a plain ReLU).
        bias: whether the convolutions carry a bias term.
        generic: if True, ``layer5`` emits ``num_classes`` scores and
            ``forward`` selects the score(s) for ``letter``. Otherwise it emits
            a single score.
        in_channels: channel count of the input tensor (default 4).
        num_classes: number of per-class scores when ``generic`` (default 26).
    """

    def __init__(self, bnorm=True, leakyparam=0.0, bias=False, generic=False,
                 in_channels=4, num_classes=26):
        super().__init__()
        self.bnorm = bnorm
        self.generic = generic
        # In-place is safe here: the activation is only applied to freshly
        # computed conv / BatchNorm outputs, never to the caller's tensor.
        self.relu = nn.LeakyReLU(leakyparam, inplace=True)
        self.sig = nn.Sigmoid()
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)
        self.layer1 = nn.Conv2d(in_channels, 64, 4, 2, 1, bias=bias)
        self.layer2 = nn.Conv2d(64, 128, 4, 2, 1, bias=bias)
        self.layer3 = nn.Conv2d(128, 256, 4, 2, 1, bias=bias)
        self.layer4 = nn.Conv2d(256, 512, 4, 2, 1, bias=bias)
        out_channels = num_classes if generic else 1
        self.layer5 = nn.Conv2d(512, out_channels, 2, 1, 0, bias=bias)

    def forward(self, input, letter):
        """Return the mean sigmoid score as a 0-dim tensor.

        Args:
            input: batch assumed to be ``(B, in_channels, H, W)``; H = W = 32
                gives a 1x1 score map.
            letter: ignored unless ``generic``. Accepted forms:
                - a single class index, scored for every sample;
                - a 1-D sequence/tensor of length ``B``, one index per sample;
                - any other index sequence, applied to every sample.

        Returns:
            Scalar tensor: the mean of the selected sigmoid scores (all scores
            when not ``generic``).
        """
        out = self.relu(self.layer1(input))
        stages = (
            (self.layer2, self.bn2),
            (self.layer3, self.bn3),
            (self.layer4, self.bn4),
        )
        for conv, bn in stages:
            out = conv(out)
            if self.bnorm:
                out = bn(out)
            out = self.relu(out)
        scores = self.sig(self.layer5(out))
        if not self.generic:
            return scores.mean()
        return self._letter_scores(scores, letter).mean()

    @staticmethod
    def _letter_scores(scores, letter):
        """Select the class score(s) named by ``letter`` from a (B, C, H, W) map."""
        # Keep the batch and class axes separate. Indexing a fully flattened
        # (B*C*H*W) vector would make class indices address sample 0 only.
        scores = scores.reshape(scores.shape[0], scores.shape[1], -1)
        batch = scores.shape[0]
        idx = torch.as_tensor(letter, device=scores.device).long()
        if idx.dim() == 1 and idx.numel() == batch:
            # One class index per sample.
            rows = torch.arange(batch, device=scores.device)
            return scores[rows, idx]
        # Scalar (or other) index: same class(es) for every sample.
        return scores[:, idx]