File size: 4,496 Bytes
09fccfd
 
 
 
 
 
 
 
 
 
53b1f7f
8c693de
 
09fccfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2825685
09fccfd
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from os import mkdir
from os.path import exists

import numpy as np

import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from image_dataset import ImageDataset
from discriminator import Discriminator
from generator import Generator


class ImageWgan:
    """Wasserstein GAN with weight clipping for image data.

    Owns a ``Generator`` and a ``Discriminator`` (acting as the critic) and
    offers a training loop (``train``) and a helper that writes a grid of
    generated samples to disk (``generate``).
    """

    def __init__(
        self,
        image_shape: (int, int, int),
        latent_space_dimension: int = 100,
        use_cuda: bool = False,
        generator_saved_model: str or None = None,
        discriminator_saved_model: str or None = None
    ):
        """Build the generator/critic pair.

        Args:
            image_shape: Three-integer image shape forwarded unchanged to both
                networks (axis order is defined by ``Generator`` /
                ``Discriminator`` -- presumably channels first; confirm there).
            latent_space_dimension: Length of the noise vector fed to the
                generator.
            use_cuda: When true, both networks are moved to the GPU and all
                tensors created by this class are allocated there.
            generator_saved_model: Forwarded to ``Generator``; presumably a
                path to previously saved weights (confirm in ``Generator``).
            discriminator_saved_model: Forwarded to ``Discriminator``; same
                caveat as above.
        """
        self.generator = Generator(image_shape, latent_space_dimension, use_cuda, generator_saved_model)
        self.discriminator = Discriminator(image_shape, use_cuda, discriminator_saved_model)

        self.image_shape = image_shape
        self.latent_space_dimension = latent_space_dimension
        self.use_cuda = use_cuda
        if use_cuda:
            self.generator.cuda()
            self.discriminator.cuda()

    def _device(self) -> torch.device:
        """Return the device on which this instance allocates its tensors."""
        return torch.device('cuda' if self.use_cuda else 'cpu')

    def _sample_latent(self, count: int) -> torch.Tensor:
        """Draw ``count`` standard-normal latent vectors as a float32 tensor.

        The noise is sampled with NumPy (not torch) so that seeding via
        ``np.random.seed`` keeps controlling reproducibility.
        """
        noise = np.random.normal(0, 1, (count, self.latent_space_dimension))
        return torch.as_tensor(noise, dtype=torch.float32, device=self._device())

    def train(
        self,
        image_dataset: ImageDataset,
        learning_rate: float = 0.00005,
        batch_size: int = 64,
        workers: int = 8,
        epochs: int = 100,
        clip_value: float = 0.01,
        discriminator_steps: int = 5,
        sample_interval: int = 1000,
        sample_folder: str = 'samples',
        generator_save_file: str = 'generator.model',
        discriminator_save_file: str = 'discriminator.model'
    ):
        """Train the critic and the generator with the WGAN objective.

        The critic is updated on every batch and its weights are clamped to
        ``[-clip_value, clip_value]`` afterwards; the generator is updated on
        every ``discriminator_steps``-th batch of an epoch. Both networks are
        saved at the end of each epoch.

        Args:
            image_dataset: Dataset yielding image tensors (one tensor per
                item, batched by ``DataLoader``).
            learning_rate: Learning rate of both RMSprop optimizers.
            batch_size: Number of images per batch.
            workers: Number of ``DataLoader`` worker processes.
            epochs: Number of passes over the dataset.
            clip_value: Bound used to clamp the critic's weights.
            discriminator_steps: Critic updates per generator update.
            sample_interval: A sample grid is written every this many batches
                (counted across epochs).
            sample_folder: Directory receiving the sample grids; created if
                it does not exist.
            generator_save_file: Target passed to ``generator.save``.
            discriminator_save_file: Target passed to ``discriminator.save``.
        """
        if not exists(sample_folder):
            mkdir(sample_folder)

        generator_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=learning_rate)
        discriminator_optimizer = torch.optim.RMSprop(self.discriminator.parameters(), lr=learning_rate)

        device = self._device()
        data_loader = DataLoader(
            image_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=workers
        )
        batch_count = len(data_loader)

        batches_done = 0
        for epoch in range(epochs):
            for i, imgs in enumerate(data_loader):
                real_imgs = imgs.to(device=device, dtype=torch.float32)

                # ---- critic update -------------------------------------
                discriminator_optimizer.zero_grad()

                # Sample noise as generator input
                z = self._sample_latent(imgs.shape[0])

                # Detached: the critic step must not push gradients into the generator.
                fake_imgs = self.generator(z).detach()
                # Adversarial (Wasserstein) loss
                discriminator_loss = (
                    -torch.mean(self.discriminator(real_imgs))
                    + torch.mean(self.discriminator(fake_imgs))
                )

                discriminator_loss.backward()
                discriminator_optimizer.step()

                # Clip weights of the critic
                with torch.no_grad():
                    for p in self.discriminator.parameters():
                        p.clamp_(-clip_value, clip_value)

                # Images that a sample snapshot taken on this batch should show.
                latest_samples = fake_imgs

                # ---- generator update (every `discriminator_steps` batches)
                if i % discriminator_steps == 0:
                    generator_optimizer.zero_grad()

                    # Generate a batch of images (with gradients this time)
                    gen_imgs = self.generator(z)
                    # Adversarial loss
                    generator_loss = -torch.mean(self.discriminator(gen_imgs))

                    generator_loss.backward()
                    generator_optimizer.step()

                    latest_samples = gen_imgs.detach()

                    print(
                        f'[Epoch {epoch}/{epochs}] [Batch {i}/{batch_count}] ' +
                        f'[D loss: {discriminator_loss.item()}] [G loss: {generator_loss.item()}]'
                    )

                if batches_done % sample_interval == 0:
                    # Use images from the current batch rather than whatever
                    # the last generator step (possibly several batches ago)
                    # happened to leave behind.
                    save_image(latest_samples[:25], f'{sample_folder}/{batches_done}.png', nrow=5, normalize=True)
                batches_done += 1

            # Checkpoint both networks once per epoch.
            self.discriminator.save(discriminator_save_file)
            self.generator.save(generator_save_file)

    def generate(
        self,
        sample_folder: str = 'samples',
        sample_count: int = 25
    ):
        """Write a grid of freshly generated images to ``generated.png``.

        Args:
            sample_folder: Directory receiving ``generated.png``; created if
                it does not exist.
            sample_count: Number of images to generate. Defaults to 25, which
                fills the 5-column grid; previously the batch size was taken
                from ``image_shape[0]``, which is an image dimension rather
                than a sample count.
        """
        if not exists(sample_folder):
            mkdir(sample_folder)

        # Inference only: no critic pass and no backward pass are needed, and
        # disabling autograd avoids accumulating stale gradients on the models.
        with torch.no_grad():
            gen_imgs = self.generator(self._sample_latent(sample_count))
        save_image(gen_imgs, f'{sample_folder}/generated.png', nrow=5, normalize=True)