Minecraft_Skin_Generator / image_wgan.py
meeww's picture
Update image_wgan.py
2825685
raw
history blame
No virus
4.5 kB
from os import mkdir
from os.path import exists
import numpy as np
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from image_dataset import ImageDataset
from discriminator import Discriminator
from generator import Generator
class ImageWgan:
def __init__(
self,
image_shape: (int, int, int),
latent_space_dimension: int = 100,
use_cuda: bool = False,
generator_saved_model: str or None = None,
discriminator_saved_model: str or None = None
):
self.generator = Generator(image_shape, latent_space_dimension, use_cuda, generator_saved_model)
self.discriminator = Discriminator(image_shape, use_cuda, discriminator_saved_model)
self.image_shape = image_shape
self.latent_space_dimension = latent_space_dimension
self.use_cuda = use_cuda
if use_cuda:
self.generator.cuda()
self.discriminator.cuda()
def train(
self,
image_dataset: ImageDataset,
learning_rate: float = 0.00005,
batch_size: int = 64,
workers: int = 8,
epochs: int = 100,
clip_value: float = 0.01,
discriminator_steps: int = 5,
sample_interval: int = 1000,
sample_folder: str = 'samples',
generator_save_file: str = 'generator.model',
discriminator_save_file: str = 'discriminator.model'
):
if not exists(sample_folder):
mkdir(sample_folder)
generator_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=learning_rate)
discriminator_optimizer = torch.optim.RMSprop(self.discriminator.parameters(), lr=learning_rate)
Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
data_loader = torch.utils.data.DataLoader(
image_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=workers
)
batches_done = 0
for epoch in range(epochs):
for i, imgs in enumerate(data_loader):
real_imgs = Variable(imgs.type(Tensor))
discriminator_optimizer.zero_grad()
# Sample noise as generator input
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], self.latent_space_dimension))))
fake_imgs = self.generator(z).detach()
# Adversarial loss
discriminator_loss = -torch.mean(self.discriminator(real_imgs)) + torch.mean(self.discriminator(fake_imgs))
discriminator_loss.backward()
discriminator_optimizer.step()
# Clip weights of discriminator
for p in self.discriminator.parameters():
p.data.clamp_(-clip_value, clip_value)
# Train the generator every n_critic iterations
if i % discriminator_steps == 0:
generator_optimizer.zero_grad()
# Generate a batch of images
gen_imgs = self.generator(z)
# Adversarial loss
generator_loss = -torch.mean(self.discriminator(gen_imgs))
generator_loss.backward()
generator_optimizer.step()
print(
f'[Epoch {epoch}/{epochs}] [Batch {batches_done % len(data_loader)}/{len(data_loader)}] ' +
f'[D loss: {discriminator_loss.item()}] [G loss: {generator_loss.item()}]'
)
if batches_done % sample_interval == 0:
save_image(gen_imgs.data[:25], f'{sample_folder}/{batches_done}.png', nrow=5, normalize=True)
batches_done += 1
self.discriminator.save(discriminator_save_file)
self.generator.save(generator_save_file)
def generate(
self,
sample_folder: str = 'samples'
):
if not exists(sample_folder):
mkdir(sample_folder)
Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
z = Variable(Tensor(np.random.normal(0, 1, (self.image_shape[0], self.latent_space_dimension))))
gen_imgs = self.generator(z)
generator_loss = -torch.mean(self.discriminator(gen_imgs))
generator_loss.backward()
save_image(gen_imgs.data[:25], f'{sample_folder}/generated.png', nrow=5, normalize=True)