|
from os import mkdir |
|
from os.path import exists |
|
|
|
import numpy as np |
|
|
|
import torch |
|
from torch.autograd import Variable |
|
from torch.utils.data import DataLoader |
|
from torchvision.utils import save_image |
|
|
|
from image_dataset import ImageDataset |
|
from discriminator import Discriminator |
|
from generator import Generator |
|
|
|
|
|
class ImageWgan: |
|
def __init__( |
|
self, |
|
image_shape: (int, int, int), |
|
latent_space_dimension: int = 100, |
|
use_cuda: bool = False, |
|
generator_saved_model: str or None = None, |
|
discriminator_saved_model: str or None = None |
|
): |
|
self.generator = Generator(image_shape, latent_space_dimension, use_cuda, generator_saved_model) |
|
self.discriminator = Discriminator(image_shape, use_cuda, discriminator_saved_model) |
|
|
|
self.image_shape = image_shape |
|
self.latent_space_dimension = latent_space_dimension |
|
self.use_cuda = use_cuda |
|
if use_cuda: |
|
self.generator.cuda() |
|
self.discriminator.cuda() |
|
|
|
def train( |
|
self, |
|
image_dataset: ImageDataset, |
|
learning_rate: float = 0.00005, |
|
batch_size: int = 64, |
|
workers: int = 8, |
|
epochs: int = 100, |
|
clip_value: float = 0.01, |
|
discriminator_steps: int = 5, |
|
sample_interval: int = 1000, |
|
sample_folder: str = 'samples', |
|
generator_save_file: str = 'generator.model', |
|
discriminator_save_file: str = 'discriminator.model' |
|
): |
|
if not exists(sample_folder): |
|
mkdir(sample_folder) |
|
|
|
generator_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=learning_rate) |
|
discriminator_optimizer = torch.optim.RMSprop(self.discriminator.parameters(), lr=learning_rate) |
|
|
|
Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor |
|
|
|
data_loader = torch.utils.data.DataLoader( |
|
image_dataset, |
|
batch_size=batch_size, |
|
shuffle=True, |
|
num_workers=workers |
|
) |
|
batches_done = 0 |
|
for epoch in range(epochs): |
|
for i, imgs in enumerate(data_loader): |
|
real_imgs = Variable(imgs.type(Tensor)) |
|
|
|
discriminator_optimizer.zero_grad() |
|
|
|
|
|
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], self.latent_space_dimension)))) |
|
|
|
fake_imgs = self.generator(z).detach() |
|
|
|
discriminator_loss = -torch.mean(self.discriminator(real_imgs)) + torch.mean(self.discriminator(fake_imgs)) |
|
|
|
discriminator_loss.backward() |
|
discriminator_optimizer.step() |
|
|
|
|
|
for p in self.discriminator.parameters(): |
|
p.data.clamp_(-clip_value, clip_value) |
|
|
|
|
|
if i % discriminator_steps == 0: |
|
generator_optimizer.zero_grad() |
|
|
|
|
|
gen_imgs = self.generator(z) |
|
|
|
generator_loss = -torch.mean(self.discriminator(gen_imgs)) |
|
|
|
generator_loss.backward() |
|
generator_optimizer.step() |
|
|
|
print( |
|
f'[Epoch {epoch}/{epochs}] [Batch {batches_done % len(data_loader)}/{len(data_loader)}] ' + |
|
f'[D loss: {discriminator_loss.item()}] [G loss: {generator_loss.item()}]' |
|
) |
|
|
|
if batches_done % sample_interval == 0: |
|
save_image(gen_imgs.data[:25], f'{sample_folder}/{batches_done}.png', nrow=5, normalize=True) |
|
batches_done += 1 |
|
self.discriminator.save(discriminator_save_file) |
|
self.generator.save(generator_save_file) |
|
|
|
def generate( |
|
self, |
|
sample_folder: str = 'samples' |
|
): |
|
if not exists(sample_folder): |
|
mkdir(sample_folder) |
|
|
|
Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor |
|
z = Variable(Tensor(np.random.normal(0, 1, (self.image_shape[0], self.latent_space_dimension)))) |
|
gen_imgs = self.generator(z) |
|
generator_loss = -torch.mean(self.discriminator(gen_imgs)) |
|
generator_loss.backward() |
|
save_image(gen_imgs.data[:25], f'{sample_folder}/generated.png', nrow=5, normalize=True) |
|
|