import glob
import numpy as np
import pickle
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from torchvision.utils import make_grid
# default image directory (training only)
img_path = '/home/alan/Projects/gen_dnd_art/filtered_images/im128/*pkl'
img_files = glob.glob(img_path)
# determine class names from image directory (training only)
'''
labels = np.array([i.split('/')[-1].split('_')[:3] for i in img_files])
species = np.unique(labels[:, 0]).tolist()
classes = np.unique(labels[:, 1]).tolist()
genders = np.unique(labels[:, 2]).tolist()
'''
# hard code class labels (for application)
species = ['dragonborn', 'dwarf', 'elf', 'gnome', 'halfling', 'human', 'orc', 'tiefling']
classes = [
    'barbarian', 'bard', 'cleric', 'druid', 'fighter', 'monk',
    'paladin', 'ranger', 'rogue', 'sorcerer', 'warlock', 'wizard'
]
genders = ['', 'female', 'male']

class ImSet(Dataset):
    def __init__(self, img_path=img_path):
        super().__init__()
        self.img_files = glob.glob(img_path)
        self.transform = T.Compose([
            T.ToTensor(),
            T.ColorJitter(0.1, 0.1, 0.1, 0.1),
            T.RandomHorizontalFlip(),
            # add random noise and clip to [0, 1]
            lambda x: torch.clip(torch.randn(x.shape) / 20 + x, 0, 1),
            # rescale to [-1, 1] to match the decoder's Tanh output range
            T.Normalize(0.5, 0.5)
        ])

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, i):
        img_file = self.img_files[i]
        # load pickled image array
        with open(img_file, 'rb') as fid:
            img = pickle.load(fid)
        # apply transforms
        img = self.transform(img).float()
        # extract class labels from the filename: <species>_<class>_<gender>_..._...
        img_fname = img_file.split('/')[-1]
        species_, class_, gender_, _, _ = img_fname.split('_')
        species_ = species.index(species_)
        class_ = classes.index(class_)
        gender_ = genders.index(gender_)
        return (img_fname, img, species_, class_, gender_)
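
# minimal usage sketch (not part of the original script): wrap the dataset in
# a DataLoader for training; the batch size and shuffle flag are assumptions,
# not values taken from the original training run
#
#     train_set = ImSet()
#     train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
#     fnames, imgs, sp, cl, ge = next(iter(train_loader))
#     # imgs: (64, 3, 128, 128); sp, cl, ge: (64,) label-index tensors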

class VariationalEncoder(nn.Module):
    def __init__(self, input_channels=3, latent_size=2048):
        super().__init__()
        self.latent_size = latent_size
        self.net = nn.Sequential(
            # 128 -> 63
            nn.Conv2d(input_channels, 8, 4, 2),
            nn.LeakyReLU(0.2),
            # 63 -> 31
            nn.Conv2d(8, 16, 3, 2),
            nn.LeakyReLU(0.2),
            # 31 -> 15
            nn.Conv2d(16, 32, 3, 2),
            nn.LeakyReLU(0.2),
            # 15 -> 7
            nn.Conv2d(32, 64, 3, 2),
            nn.LeakyReLU(0.2),
            # 7 -> 5
            nn.Conv2d(64, 128, 3, 1),
            nn.LeakyReLU(0.2),
            # 5 -> 4
            nn.Conv2d(128, 256, 2, 1),
            nn.LeakyReLU(0.2),
            # 4 -> 3
            nn.Conv2d(256, 512, 2, 1),
            nn.LeakyReLU(0.2),
            # 3 -> 2
            nn.Conv2d(512, 1024, 2, 1),
            nn.LeakyReLU(0.2),
            # 2 -> 1
            nn.Conv2d(1024, latent_size, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4)
        )
        # parameters for variational autoencoder
        self.mu = nn.Linear(latent_size, latent_size)
        self.sigma = nn.Linear(latent_size, latent_size)
        self.N = torch.distributions.Normal(0, 1)
        # self.N.loc = self.N.loc.cuda()
        # self.N.scale = self.N.scale.cuda()
        self.kl = 0

    def forward(self, x):
        x = self.net(x)
        mu = self.mu(x)
        sigma = torch.exp(self.sigma(x))
        # reparameterisation trick: z = mu + sigma * eps, with eps ~ N(0, 1)
        x = mu + sigma * self.N.sample(mu.shape)
        # KL term pushing the code distribution towards the standard normal prior
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        return x
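
# shape sanity check for the encoder (a sketch; the batch size of 4 is arbitrary):
#
#     enc = VariationalEncoder(latent_size=2048)
#     z = enc(torch.rand(4, 3, 128, 128))
#     assert z.shape == (4, 2048)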

class ConditionalEncoder(VariationalEncoder):
    def __init__(self, latent_size=2048):
        # fourth input channel carries the label embeddings
        super().__init__(input_channels=4, latent_size=latent_size)
        # the three embeddings together fill one 128x128 plane (16384 values)
        self.emb_species = nn.Embedding(len(species), 128**2 // 3 + 128**2 % 3)
        self.emb_class = nn.Embedding(len(classes), 128**2 // 3)
        self.emb_gender = nn.Embedding(len(genders), 128**2 // 3)
        self.emb_reshape = nn.Unflatten(1, (1, 128, 128))

    def forward(self, img, species_, class_, gender_):
        x = self.emb_species(species_)
        y = self.emb_class(class_)
        z = self.emb_gender(gender_)
        # stack the label embeddings into a single extra image channel
        x = torch.concat([x, y, z], dim=1)
        x = self.emb_reshape(x)
        x = torch.concat([img, x], dim=1)
        x = self.net(x)
        mu = self.mu(x)
        sigma = torch.exp(self.sigma(x))
        x = mu + sigma * self.N.sample(mu.shape)
        self.kl = (sigma**2 + mu**2 - torch.log(sigma) - 1/2).sum()
        return x

class Decoder(nn.Module):
    def __init__(self, latent_size=2048):
        super().__init__()
        self.latent_size = latent_size
        self.net = nn.Sequential(
            nn.Linear(latent_size, latent_size),
            nn.Dropout(0.4),
            nn.Unflatten(1, (latent_size, 1, 1)),
            # 1 -> 2
            nn.ConvTranspose2d(latent_size, 1024, 2, 1),
            nn.LeakyReLU(0.2),
            # 2 -> 3
            nn.ConvTranspose2d(1024, 512, 2, 1),
            nn.LeakyReLU(0.2),
            # 3 -> 4
            nn.ConvTranspose2d(512, 256, 2, 1),
            nn.LeakyReLU(0.2),
            # 4 -> 5
            nn.ConvTranspose2d(256, 128, 2, 1),
            nn.LeakyReLU(0.2),
            # 5 -> 7
            nn.ConvTranspose2d(128, 64, 3, 1),
            nn.LeakyReLU(0.2),
            # 7 -> 15
            nn.ConvTranspose2d(64, 32, 3, 2),
            nn.LeakyReLU(0.2),
            # 15 -> 31
            nn.ConvTranspose2d(32, 16, 3, 2),
            nn.LeakyReLU(0.2),
            # 31 -> 63
            nn.ConvTranspose2d(16, 8, 3, 2),
            nn.LeakyReLU(0.2),
            # 63 -> 128
            nn.ConvTranspose2d(8, 3, 4, 2),
            nn.Tanh()
        )

    def forward(self, x):
        return self.net(x)
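
# mirror-image sanity check for the decoder (a sketch; batch size is arbitrary):
#
#     dec = Decoder(latent_size=2048)
#     img = dec(torch.randn(4, 2048))
#     assert img.shape == (4, 3, 128, 128)   # Tanh output in [-1, 1]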

class ConditionalDecoder(Decoder):
    def __init__(self, latent_size=1024):
        super().__init__(latent_size)
        # the three embeddings together span exactly latent_size values
        self.emb_species = nn.Embedding(len(species), latent_size // 3 + latent_size % 3)
        self.emb_class = nn.Embedding(len(classes), latent_size // 3)
        self.emb_gender = nn.Embedding(len(genders), latent_size // 3)
        self.label_net = nn.Linear(2 * latent_size, latent_size)

    def forward(self, Z, species_, class_, gender_):
        x = self.emb_species(species_)
        y = self.emb_class(class_)
        z = self.emb_gender(gender_)
        # concatenate the latent code with the label embeddings, then project
        # back down to latent_size before decoding
        x = torch.concat([Z, x, y, z], dim=1)
        x = self.label_net(x)
        x = self.net(x)
        return x

class VariationalAutoEncoder(nn.Module):
    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        # pass latent_size by keyword: the encoder's first positional
        # argument is input_channels, not latent_size
        self.enc = VariationalEncoder(latent_size=latent_size)
        self.dec = Decoder(latent_size)

    def forward(self, x):
        return self.dec(self.enc(x))

class ConditionalVariationalAutoEncoder(nn.Module):
    def __init__(self, latent_size=1024):
        super().__init__()
        self.latent_size = latent_size
        self.enc = ConditionalEncoder(latent_size)
        self.dec = ConditionalDecoder(latent_size)

    def forward(self, img, species_, class_, gender_):
        Z = self.enc(img, species_, class_, gender_)
        x = self.dec(Z, species_, class_, gender_)
        return x
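
# training objective sketch (an assumption, not code from this file): the
# encoder stores its KL term on self.kl during forward, so a training step
# can combine it with a reconstruction loss; the unit KL weight is a guess
#
#     model = ConditionalVariationalAutoEncoder(latent_size=1024)
#     recon = model(imgs, sp, cl, ge)
#     loss = ((recon - imgs) ** 2).sum() + model.enc.kl
#     loss.backward()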

def show_tensor(Z, ax, **kwargs):
    # drop the batch dimension if one is present
    if len(Z.shape) > 3:
        Z = Z[0]
    # rescale Tanh-range images from [-1, 1] to [0, 1]
    if Z.min() < 0:
        Z = (Z + 1) / 2
    # CHW -> HWC for matplotlib
    Z = np.transpose(Z.detach().cpu().numpy(), (1, 2, 0))
    ax.imshow(Z, **kwargs)
    return ax
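
# minimal inference sketch (assumed; not part of the original script): sample
# a latent code from the prior and decode it conditioned on a chosen label
# combination. with untrained weights this produces noise; the commented-out
# checkpoint path 'cvae.pt' is hypothetical.
if __name__ == '__main__':
    import matplotlib.pyplot as plt

    model = ConditionalVariationalAutoEncoder(latent_size=1024)
    # model.load_state_dict(torch.load('cvae.pt', map_location='cpu'))
    model.eval()

    with torch.no_grad():
        Z = torch.randn(1, model.latent_size)
        s = torch.tensor([species.index('elf')])
        c = torch.tensor([classes.index('wizard')])
        g = torch.tensor([genders.index('female')])
        img = model.dec(Z, s, c, g)

    fig, ax = plt.subplots()
    show_tensor(img, ax)
    plt.show()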