# NOTE: this module was exported from a Hugging Face file view (commit b28391f,
# "added utisl", 7.54 kB); the viewer header was not part of the original source.
import glob
import numpy as np
import pickle
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.utils import make_grid
img_path = '/home/alan/Projects/gen_dnd_art/filtered_images/im128/*pkl'
img_files = glob.glob(img_path)
# Filenames look like '<species>_<class>_<gender>_..._...' (see ImSet.__getitem__);
# the first three underscore-separated fields define the label vocabularies used
# by every model class below.
if img_files:
    labels = np.array([f.split('/')[-1].split('_')[:3] for f in img_files])
    species = np.unique(labels[:, 0]).tolist()
    classes = np.unique(labels[:, 1]).tolist()
    genders = np.unique(labels[:, 2]).tolist()
else:
    # No images found (e.g. the data directory only exists on the author's
    # machine): fall back to empty vocabularies instead of crashing with an
    # IndexError on `labels[:, 0]` over a zero-row array.
    labels = np.empty((0, 3), dtype=str)
    species, classes, genders = [], [], []
class ImSet(Dataset):
    """Dataset of pickled character portraits with light augmentation.

    Each pickle file is assumed to hold one 128x128 image whose filename
    encodes its labels as '<species>_<class>_<gender>_<x>_<y>' (exactly five
    underscore-separated fields) -- TODO confirm against the data directory.
    """

    def __init__(self, img_path=img_path):
        super().__init__()
        self.img_files = glob.glob(img_path)
        # Augmentation pipeline: jitter + flip, then a small amount of
        # Gaussian noise clipped back into [0, 1], then map to [-1, 1].
        augmentations = [
            T.ToTensor(),
            T.ColorJitter(0.1, 0.1, 0.1, 0.1),
            T.RandomHorizontalFlip(),
            lambda t: torch.clip(t + torch.randn(t.shape) / 20, 0, 1),
            T.Normalize(0.5, 0.5),
        ]
        self.transform = T.Compose(augmentations)

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, i):
        path = self.img_files[i]
        # NOTE: pickle.load is only safe because these files are locally produced.
        with open(path, 'rb') as fid:
            image = pickle.load(fid)
        image = self.transform(image).float()
        # Decode labels from the filename and map them to vocabulary indices.
        fname = path.split('/')[-1]
        sp, cl, ge, _, _ = fname.split('_')
        return (fname, image, species.index(sp), classes.index(cl), genders.index(ge))
class VariationalEncoder(nn.Module):
    """Convolutional variational encoder for 128x128 images.

    ``forward`` maps a batch of images to a latent sample of size
    ``latent_size`` via the reparameterization trick, and stores the KL
    divergence of the posterior from N(0, I) in ``self.kl`` as a side effect.
    """

    def __init__(self, input_channels=3, latent_size=2048):
        super().__init__()
        self.latent_size = latent_size
        self.net = nn.Sequential(
            # Spatial sizes: 128 -> 63
            nn.Conv2d(input_channels, 8, 4, 2),
            nn.LeakyReLU(0.2),
            # 63 -> 31
            nn.Conv2d(8, 16, 3, 2),
            nn.LeakyReLU(0.2),
            # 31 -> 15
            nn.Conv2d(16, 32, 3, 2),
            nn.LeakyReLU(0.2),
            # 15 -> 7
            nn.Conv2d(32, 64, 3, 2),
            nn.LeakyReLU(0.2),
            # 7 -> 5
            nn.Conv2d(64, 128, 3, 1),
            nn.LeakyReLU(0.2),
            # 5 -> 4
            nn.Conv2d(128, 256, 2, 1),
            nn.LeakyReLU(0.2),
            # 4 -> 3
            nn.Conv2d(256, 512, 2, 1),
            nn.LeakyReLU(0.2),
            # 3 -> 2
            nn.Conv2d(512, 1024, 2, 1),
            nn.LeakyReLU(0.2),
            # 2 -> 1
            nn.Conv2d(1024, latent_size, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4)
        )
        # Heads producing the posterior parameters mu and (log-space) sigma.
        self.mu = nn.Linear(latent_size, latent_size)
        self.sigma = nn.Linear(latent_size, latent_size)
        # Standard normal for the reparameterization trick. Kept on CPU and
        # moved per-sample in forward(); the original unconditionally called
        # .cuda() here, which crashes on CPU-only hosts.
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

    def forward(self, x):
        x = self.net(x)
        mu = self.mu(x)
        # exp() guarantees a positive standard deviation.
        sigma = torch.exp(self.sigma(x))
        # Reparameterization trick; sample on the distribution's device, then
        # move to wherever the activations live (works on CPU and GPU).
        eps = self.N.sample(mu.shape).to(mu.device)
        z = mu + sigma * eps
        # KL(N(mu, sigma^2) || N(0, 1)) = 0.5*(sigma^2 + mu^2 - 1) - log(sigma),
        # summed over batch and latent dims. The original dropped the 0.5
        # factor on the sigma^2 and mu^2 terms.
        self.kl = ((sigma**2 + mu**2 - 1) / 2 - torch.log(sigma)).sum()
        return z
class ConditionalEncoder(VariationalEncoder):
    """Variational encoder conditioned on species/class/gender labels.

    The three label embeddings are concatenated into one 128*128 vector,
    reshaped into a single-channel "label plane", and stacked onto the image
    as a fourth input channel (hence ``input_channels=4``).
    """

    def __init__(self, latent_size=2048):
        super().__init__(input_channels=4, latent_size=latent_size)
        # Embedding widths sum to exactly 128**2; the remainder from the
        # integer division goes to the species embedding.
        self.emb_species = nn.Embedding(len(species), 128**2 // 3 + 128**2 % 3)
        self.emb_class = nn.Embedding(len(classes), 128**2 // 3)
        self.emb_gender = nn.Embedding(len(genders), 128**2 // 3)
        self.emb_reshape = nn.Unflatten(1, (1, 128, 128))

    def forward(self, img, species_, class_, gender_):
        # Build the conditioning plane and append it as an extra channel.
        emb = torch.concat([
            self.emb_species(species_),
            self.emb_class(class_),
            self.emb_gender(gender_),
        ], dim=1)
        x = torch.concat([img, self.emb_reshape(emb)], dim=1)
        x = self.net(x)
        mu = self.mu(x)
        sigma = torch.exp(self.sigma(x))
        # Device-safe sampling (see VariationalEncoder.forward).
        eps = self.N.sample(mu.shape).to(mu.device)
        # Corrected closed-form KL to N(0, 1): 0.5*(sigma^2 + mu^2 - 1) - log(sigma).
        # The original omitted the 0.5 on the sigma^2 and mu^2 terms.
        self.kl = ((sigma**2 + mu**2 - 1) / 2 - torch.log(sigma)).sum()
        return mu + sigma * eps
class Decoder(nn.Module):
    """Mirror of the encoder: maps a latent vector to a 3x128x128 image.

    Output values are squashed into [-1, 1] by the final Tanh, matching the
    Normalize(0.5, 0.5) scaling used on the inputs.
    """

    def __init__(self, latent_size=2048):
        super().__init__()
        self.latent_size = latent_size
        # (in_channels, out_channels, kernel, stride) per upsampling stage;
        # spatial sizes: 1 -> 2 -> 3 -> 4 -> 5 -> 7 -> 15 -> 31 -> 63 (-> 128 below).
        stages = [
            (latent_size, 1024, 2, 1),
            (1024, 512, 2, 1),
            (512, 256, 2, 1),
            (256, 128, 2, 1),
            (128, 64, 3, 1),
            (64, 32, 3, 2),
            (32, 16, 3, 2),
            (16, 8, 3, 2),
        ]
        layers = [
            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4),
            nn.Unflatten(1, (latent_size, 1, 1)),
        ]
        for in_ch, out_ch, kernel, stride in stages:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel, stride))
            layers.append(nn.LeakyReLU(0.2))
        # Final stage: 63 -> 128 with 3 output channels, squashed to [-1, 1].
        layers.append(nn.ConvTranspose2d(8, 3, 4, 2))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class ConditionalDecoder(Decoder):
    """Decoder conditioned on species/class/gender labels.

    The three label embeddings total ``latent_size`` entries; they are
    concatenated with the latent vector and projected back down to
    ``latent_size`` before running the shared deconvolution stack.
    """

    def __init__(self, latent_size=1024):
        super().__init__(latent_size)
        # Embedding widths sum to exactly latent_size (remainder to species).
        self.emb_species = nn.Embedding(len(species), latent_size // 3 + latent_size % 3)
        self.emb_class = nn.Embedding(len(classes), latent_size // 3)
        self.emb_gender = nn.Embedding(len(genders), latent_size // 3)
        # Projects [latent | label embeddings] back to latent_size.
        self.label_net = nn.Linear(2 * latent_size, latent_size)

    def forward(self, Z, species_, class_, gender_):
        sp = self.emb_species(species_)
        cl = self.emb_class(class_)
        ge = self.emb_gender(gender_)
        conditioned = self.label_net(torch.concat([Z, sp, cl, ge], dim=1))
        return self.net(conditioned)
class VariationalAutoEncoder(nn.Module):
    """Unconditional VAE: VariationalEncoder followed by Decoder.

    The encoder's KL term remains available as ``self.enc.kl`` after a
    forward pass.
    """

    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        # BUG FIX: the original called VariationalEncoder(latent_size), which
        # passed latent_size positionally into input_channels and left the
        # encoder at its default latent size of 2048 -- mismatching the
        # decoder and breaking the forward pass. Pass it by keyword.
        self.enc = VariationalEncoder(latent_size=latent_size)
        self.dec = Decoder(latent_size)

    def forward(self, x):
        return self.dec(self.enc(x))
class ConditionalVariationalAutoEncoder(nn.Module):
    """Conditional VAE: label-aware encoder and decoder share one latent size."""

    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        self.enc = ConditionalEncoder(latent_size)
        self.dec = ConditionalDecoder(latent_size)

    def forward(self, img, species_, class_, gender_):
        # Encode with the labels, then decode with the same labels.
        latent = self.enc(img, species_, class_, gender_)
        return self.dec(latent, species_, class_, gender_)
def show_tensor(Z, ax, **kwargs):
    """Display an image tensor on a matplotlib axes and return the axes.

    Z is a (C, H, W) tensor, or (N, C, H, W) in which case only the first
    image of the batch is shown. Tensors containing negative values are
    assumed to come from Normalize(0.5, 0.5) (range [-1, 1]) and are
    rescaled to [0, 1] before plotting. Extra kwargs go to ``ax.imshow``.
    """
    if len(Z.shape) > 3:
        # Batched input: show only the first image.
        Z = Z[0]
    # BUG FIX: the original tested `Z.min() < 1`, which also "rescaled"
    # ordinary [0, 1] images (shifting them into [0.5, 1]); only tensors
    # with negative values need un-normalizing.
    if Z.min() < 0:
        Z = (Z + 1) / 2
    # CHW tensor -> HWC numpy array for imshow.
    Z = np.transpose(Z.detach().cpu().numpy(), (1, 2, 0))
    ax.imshow(Z, **kwargs)
    return ax