# CLIP_decoder / models.py
# Author: Kevin Li
# Uploaded with huggingface_hub ("Upload folder using huggingface_hub"),
# commit a13d3e2 (verified).
import torch
import torch.nn as nn
class Decoder(nn.Module):
    """Decode a flat embedding vector into a 3-channel image.

    An MLP (``fc``) lifts the embedding to ``hidden_dim * 4 * 4`` features.
    These are reshaped to a ``(hidden_dim, 4, 4)`` feature map and upsampled
    by six stride-2 transposed convolutions (4x4 -> 256x256). A final 3x3
    convolution and a sigmoid follow, so output values lie in ``(0, 1)``.

    Args:
        input_dim: Size of the input embedding vector. The original note says
            512, presumably a CLIP embedding -- confirm against callers.
        hidden_dim: Width of the first MLP layer. Also the channel count of
            the 4x4 feature map fed to the convolutional decoder (1024 in
            the original configuration).
        gamma: Standard deviation of the Gaussian noise added to the input
            embedding on every forward pass.
    """

    def __init__(self, input_dim, hidden_dim, gamma=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.gamma = gamma
        # Original configuration: input_dim=512, hidden_dim=1024.
        self.fc = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim * 4),
            nn.BatchNorm1d(hidden_dim * 4),
            nn.ReLU(),
            nn.Linear(hidden_dim * 4, hidden_dim * 8),
            nn.BatchNorm1d(hidden_dim * 8),
            nn.ReLU(),
            # hidden_dim * 4 * 4 features, reshaped to (hidden_dim, 4, 4)
            # in forward().
            nn.Linear(hidden_dim * 8, hidden_dim * 4 * 4),
            nn.BatchNorm1d(hidden_dim * 4 * 4),
            nn.ReLU(),
        )
        # Each ConvTranspose2d(kernel_size=4, stride=2, padding=1) doubles the
        # spatial size: 4 -> 8 -> 16 -> 32 -> 64 -> 128 -> 256.
        self.decoder = nn.Sequential(
            # Consumes the (hidden_dim, 4, 4) map produced by fc. This was
            # previously hard-coded to 1024, which broke any other hidden_dim.
            nn.ConvTranspose2d(hidden_dim, 768, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(768),
            nn.ReLU(),
            nn.ConvTranspose2d(768, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
            nn.Sigmoid(),
        )
        # Cast parameters/buffers to float32 *after* the layers exist.
        # Calling this before any submodule was registered converted nothing.
        self.float()

    def forward(self, z):
        """Decode a batch of embeddings into images.

        Args:
            z: Tensor of shape ``(batch, input_dim)``. In training mode the
                batch must contain more than one sample because of
                ``BatchNorm1d``.

        Returns:
            Tensor of shape ``(batch, 3, 256, 256)`` with values in ``(0, 1)``.
        """
        batch_size = z.shape[0]
        # Gaussian input noise with std ``self.gamma``, presumably as a
        # regulariser. NOTE(review): it is applied regardless of train/eval
        # mode (including in ``sample``), so outputs are stochastic unless
        # gamma == 0 -- confirm that is intended.
        z = z + self.gamma * torch.randn_like(z)
        z = self.fc(z)
        z = z.view(batch_size, self.hidden_dim, 4, 4)
        return self.decoder(z)

    def get_loss(self, emb, x):
        """Return the mean squared error between ``self(emb)`` and target images ``x``.

        Args:
            emb: Embedding batch accepted by ``forward``.
            x: Target tensor with the same shape as the decoder output.

        Returns:
            Scalar tensor holding the mean-reduced MSE.
        """
        x_hat = self.forward(emb)
        return nn.functional.mse_loss(x_hat, x, reduction="mean")

    @torch.no_grad()
    def sample(self, samples, device):
        """Decode ``samples`` after moving them to ``device``, without tracking gradients.

        NOTE(review): the module itself is not moved to ``device`` and its
        train/eval mode is not changed here. The caller presumably has already
        placed the model on ``device`` and called ``.eval()`` if BatchNorm
        running statistics are wanted. Input noise is still added by
        ``forward``.
        """
        samples = samples.to(device)
        return self.forward(samples)