import glob
import pickle

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.utils import make_grid
img_path = '/home/alan/Projects/gen_dnd_art/filtered_images/im128/*pkl'
img_files = glob.glob(img_path)

# determine class names
labels = np.array([i.split('/')[-1].split('_')[:3] for i in img_files])
species = np.unique(labels[:, 0]).tolist()
classes = np.unique(labels[:, 1]).tolist()
genders = np.unique(labels[:, 2]).tolist()
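# Filenames are assumed (by the parsing above and the unpacking in
# ImSet.__getitem__ below) to contain five underscore-separated fields,
# 'species_class_gender_xx_yy.pkl'; a name like
# 'elf_wizard_female_00_01.pkl' is an illustrative, hypothetical example.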
class ImSet(Dataset):
    def __init__(self, img_path=img_path):
        super().__init__()
        self.img_files = glob.glob(img_path)
        self.transform = T.Compose([
            T.ToTensor(),
            T.ColorJitter(0.1, 0.1, 0.1, 0.1),
            T.RandomHorizontalFlip(),
            # add random noise and clip
            lambda x: torch.clip(torch.randn(x.shape) / 20 + x, 0, 1),
            T.Normalize(0.5, 0.5)
        ])

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, i):
        img_file = self.img_files[i]
        # load image
        with open(img_file, 'rb') as fid:
            img = pickle.load(fid)
        # apply transforms
        img = self.transform(img).float()
        # extract class label
        img_fname = img_file.split('/')[-1]
        species_, class_, gender_, _, _ = img_fname.split('_')
        species_ = species.index(species_)
        class_ = classes.index(class_)
        gender_ = genders.index(gender_)
        return (img_fname, img, species_, class_, gender_)
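# Example usage (commented out so importing this file stays side-effect
# free; the batch size is an illustrative assumption):
#
#   loader = DataLoader(ImSet(), batch_size=32, shuffle=True)
#   fname, img, species_, class_, gender_ = next(iter(loader))
#   img.shape  # -> torch.Size([32, 3, 128, 128]), values in [-1, 1]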
class VariationalEncoder(nn.Module):
    def __init__(self, input_channels=3, latent_size=2048):
        super().__init__()
        self.latent_size = latent_size
        self.net = nn.Sequential(
            # 128 -> 63
            nn.Conv2d(input_channels, 8, 4, 2),
            nn.LeakyReLU(0.2),
            # 63 -> 31
            nn.Conv2d(8, 16, 3, 2),
            nn.LeakyReLU(0.2),
            # 31 -> 15
            nn.Conv2d(16, 32, 3, 2),
            nn.LeakyReLU(0.2),
            # 15 -> 7
            nn.Conv2d(32, 64, 3, 2),
            nn.LeakyReLU(0.2),
            # 7 -> 5
            nn.Conv2d(64, 128, 3, 1),
            nn.LeakyReLU(0.2),
            # 5 -> 4
            nn.Conv2d(128, 256, 2, 1),
            nn.LeakyReLU(0.2),
            # 4 -> 3
            nn.Conv2d(256, 512, 2, 1),
            nn.LeakyReLU(0.2),
            # 3 -> 2
            nn.Conv2d(512, 1024, 2, 1),
            nn.LeakyReLU(0.2),
            # 2 -> 1
            nn.Conv2d(1024, latent_size, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4)
        )
        # parameters for variational autoencoder
        self.mu = nn.Linear(latent_size, latent_size)
        self.sigma = nn.Linear(latent_size, latent_size)
        self.N = torch.distributions.Normal(0, 1)
        if torch.cuda.is_available():
            # sample directly on the GPU when one is present
            self.N.loc = self.N.loc.cuda()
            self.N.scale = self.N.scale.cuda()
        self.kl = 0
    def forward(self, x):
        x = self.net(x)
        mu = self.mu(x)
        sigma = torch.exp(self.sigma(x))
        # reparameterization trick: z = mu + sigma * eps, eps ~ N(0, 1)
        x = mu + sigma * self.N.sample(mu.shape)
        # KL divergence from N(mu, sigma^2) to the standard normal prior
        self.kl = ((sigma**2 + mu**2) / 2 - torch.log(sigma) - 1/2).sum()
        return x
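# Quick shape check (a sketch, assuming a CUDA device since the prior is
# moved to the GPU in __init__):
#
#   enc = VariationalEncoder().cuda()
#   z = enc(torch.rand(4, 3, 128, 128, device='cuda'))
#   z.shape  # -> torch.Size([4, 2048])
#   enc.kl   # scalar KL term, populated as a side effect of forward()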
class ConditionalEncoder(VariationalEncoder):
    def __init__(self, latent_size=2048):
        # a fourth input channel carries the label embeddings
        super().__init__(input_channels=4, latent_size=latent_size)
        # the three embedding widths sum to 128*128, so the concatenated
        # labels can be reshaped into a single 128x128 channel
        self.emb_species = nn.Embedding(len(species), 128**2 // 3 + 128**2 % 3)
        self.emb_class = nn.Embedding(len(classes), 128**2 // 3)
        self.emb_gender = nn.Embedding(len(genders), 128**2 // 3)
        self.emb_reshape = nn.Unflatten(1, (1, 128, 128))
    def forward(self, img, species_, class_, gender_):
        # embed each label and stack the result as an extra image channel
        x = self.emb_species(species_)
        y = self.emb_class(class_)
        z = self.emb_gender(gender_)
        x = torch.concat([x, y, z], dim=1)
        x = self.emb_reshape(x)
        x = torch.concat([img, x], dim=1)
        # the encoding and reparameterization are inherited unchanged
        return super().forward(x)
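# Conditioning sketch: the labels become a fourth input channel, so the
# call only differs from the unconditional encoder in its arguments
# (the label indices below are illustrative):
#
#   enc = ConditionalEncoder().cuda()
#   img = torch.rand(4, 3, 128, 128, device='cuda')
#   idx = torch.zeros(4, dtype=torch.long, device='cuda')
#   z = enc(img, idx, idx, idx)  # -> shape (4, 2048)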
class Decoder(nn.Module):
    def __init__(self, latent_size=2048):
        super().__init__()
        self.latent_size = latent_size
        self.net = nn.Sequential(
            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4),
            nn.Unflatten(1, (latent_size, 1, 1)),
            # 1 -> 2
            nn.ConvTranspose2d(latent_size, 1024, 2, 1),
            nn.LeakyReLU(0.2),
            # 2 -> 3
            nn.ConvTranspose2d(1024, 512, 2, 1),
            nn.LeakyReLU(0.2),
            # 3 -> 4
            nn.ConvTranspose2d(512, 256, 2, 1),
            nn.LeakyReLU(0.2),
            # 4 -> 5
            nn.ConvTranspose2d(256, 128, 2, 1),
            nn.LeakyReLU(0.2),
            # 5 -> 7
            nn.ConvTranspose2d(128, 64, 3, 1),
            nn.LeakyReLU(0.2),
            # 7 -> 15
            nn.ConvTranspose2d(64, 32, 3, 2),
            nn.LeakyReLU(0.2),
            # 15 -> 31
            nn.ConvTranspose2d(32, 16, 3, 2),
            nn.LeakyReLU(0.2),
            # 31 -> 63
            nn.ConvTranspose2d(16, 8, 3, 2),
            nn.LeakyReLU(0.2),
            # 63 -> 128
            nn.ConvTranspose2d(8, 3, 4, 2),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)
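# The transposed-conv stack inverts the encoder's spatial sizes
# (1 -> 2 -> ... -> 128); a quick check on CPU:
#
#   dec = Decoder(latent_size=2048)
#   x = dec(torch.randn(4, 2048))
#   x.shape  # -> torch.Size([4, 3, 128, 128]), values in (-1, 1) via Tanh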
class ConditionalDecoder(Decoder):
    def __init__(self, latent_size=1024):
        super().__init__(latent_size)
        # the three embedding widths sum to latent_size, so the labels
        # contribute as many features as the latent code itself
        self.emb_species = nn.Embedding(len(species), latent_size // 3 + latent_size % 3)
        self.emb_class = nn.Embedding(len(classes), latent_size // 3)
        self.emb_gender = nn.Embedding(len(genders), latent_size // 3)
        self.label_net = nn.Linear(2 * latent_size, latent_size)

    def forward(self, Z, species_, class_, gender_):
        x = self.emb_species(species_)
        y = self.emb_class(class_)
        z = self.emb_gender(gender_)
        # fuse the latent code and label embeddings back to latent_size
        x = torch.concat([Z, x, y, z], dim=1)
        x = self.label_net(x)
        x = self.net(x)
        return x
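# Sampling sketch for the conditional decoder (label indices are
# illustrative; any valid index into species/classes/genders works):
#
#   dec = ConditionalDecoder(latent_size=1024)
#   Z = torch.randn(4, 1024)
#   idx = torch.zeros(4, dtype=torch.long)
#   x = dec(Z, idx, idx, idx)  # -> shape (4, 3, 128, 128)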
class VariationalAutoEncoder(nn.Module):
    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        # pass latent_size by keyword: the encoder's first positional
        # argument is input_channels, not latent_size
        self.enc = VariationalEncoder(latent_size=latent_size)
        self.dec = Decoder(latent_size)

    def forward(self, x):
        return self.dec(self.enc(x))
class ConditionalVariationalAutoEncoder(nn.Module):
    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        self.enc = ConditionalEncoder(latent_size)
        self.dec = ConditionalDecoder(latent_size)

    def forward(self, img, species_, class_, gender_):
        Z = self.enc(img, species_, class_, gender_)
        x = self.dec(Z, species_, class_, gender_)
        return x
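# A minimal training-loop sketch (the optimizer, learning rate, loss
# weighting, and epoch count are illustrative assumptions, and a CUDA
# device is assumed as in the encoders above):
#
#   model = ConditionalVariationalAutoEncoder().cuda()
#   opt = torch.optim.Adam(model.parameters(), lr=1e-4)
#   loader = DataLoader(ImSet(), batch_size=32, shuffle=True)
#   for epoch in range(20):
#       for _, img, sp, cl, ge in loader:
#           img, sp, cl, ge = img.cuda(), sp.cuda(), cl.cuda(), ge.cuda()
#           recon = model(img, sp, cl, ge)
#           # reconstruction term plus the KL term stored by the encoder
#           loss = ((img - recon) ** 2).sum() + model.enc.kl
#           opt.zero_grad()
#           loss.backward()
#           opt.step()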
def show_tensor(Z, ax, **kwargs):
    # drop the batch dimension if one is present
    if len(Z.shape) > 3:
        Z = Z[0]
    # un-normalize from [-1, 1] back to [0, 1] for display
    if Z.min() < 0:
        Z = (Z + 1) / 2
    Z = np.transpose(Z.detach().cpu().numpy(), (1, 2, 0))
    ax.imshow(Z, **kwargs)
    return ax
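# Example (assumes matplotlib is installed; make_grid is imported at the
# top of this file):
#
#   import matplotlib.pyplot as plt
#   fig, ax = plt.subplots()
#   _, img, *_ = next(iter(DataLoader(ImSet(), batch_size=16)))
#   show_tensor(make_grid(img, nrow=4), ax)
#   plt.show()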