meeww commited on
Commit
09fccfd
1 Parent(s): c5c5c1d

Upload image_wgan.py

Browse files
Files changed (1) hide show
  1. image_wgan.py +119 -0
image_wgan.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import mkdir
2
+ from os.path import exists
3
+
4
+ import numpy as np
5
+
6
+ import torch
7
+ from torch.autograd import Variable
8
+ from torch.utils.data import DataLoader
9
+ from torchvision.utils import save_image
10
+
11
+ from ml.pytorch.image_dataset import ImageDataset
12
+ from ml.pytorch.wgan.discriminator import Discriminator
13
+ from ml.pytorch.wgan.generator import Generator
14
+
15
+
16
+ class ImageWgan:
17
+ def __init__(
18
+ self,
19
+ image_shape: (int, int, int),
20
+ latent_space_dimension: int = 100,
21
+ use_cuda: bool = False,
22
+ generator_saved_model: str or None = None,
23
+ discriminator_saved_model: str or None = None
24
+ ):
25
+ self.generator = Generator(image_shape, latent_space_dimension, use_cuda, generator_saved_model)
26
+ self.discriminator = Discriminator(image_shape, use_cuda, discriminator_saved_model)
27
+
28
+ self.image_shape = image_shape
29
+ self.latent_space_dimension = latent_space_dimension
30
+ self.use_cuda = use_cuda
31
+ if use_cuda:
32
+ self.generator.cuda()
33
+ self.discriminator.cuda()
34
+
35
+ def train(
36
+ self,
37
+ image_dataset: ImageDataset,
38
+ learning_rate: float = 0.00005,
39
+ batch_size: int = 64,
40
+ workers: int = 8,
41
+ epochs: int = 100,
42
+ clip_value: float = 0.01,
43
+ discriminator_steps: int = 5,
44
+ sample_interval: int = 1000,
45
+ sample_folder: str = 'samples',
46
+ generator_save_file: str = 'generator.model',
47
+ discriminator_save_file: str = 'discriminator.model'
48
+ ):
49
+ if not exists(sample_folder):
50
+ mkdir(sample_folder)
51
+
52
+ generator_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=learning_rate)
53
+ discriminator_optimizer = torch.optim.RMSprop(self.discriminator.parameters(), lr=learning_rate)
54
+
55
+ Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
56
+
57
+ data_loader = torch.utils.data.DataLoader(
58
+ image_dataset,
59
+ batch_size=batch_size,
60
+ shuffle=True,
61
+ num_workers=workers
62
+ )
63
+ batches_done = 0
64
+ for epoch in range(epochs):
65
+ for i, imgs in enumerate(data_loader):
66
+ real_imgs = Variable(imgs.type(Tensor))
67
+
68
+ discriminator_optimizer.zero_grad()
69
+
70
+ # Sample noise as generator input
71
+ z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], self.latent_space_dimension))))
72
+
73
+ fake_imgs = self.generator(z).detach()
74
+ # Adversarial loss
75
+ discriminator_loss = -torch.mean(self.discriminator(real_imgs)) + torch.mean(self.discriminator(fake_imgs))
76
+
77
+ discriminator_loss.backward()
78
+ discriminator_optimizer.step()
79
+
80
+ # Clip weights of discriminator
81
+ for p in self.discriminator.parameters():
82
+ p.data.clamp_(-clip_value, clip_value)
83
+
84
+ # Train the generator every n_critic iterations
85
+ if i % discriminator_steps == 0:
86
+ generator_optimizer.zero_grad()
87
+
88
+ # Generate a batch of images
89
+ gen_imgs = self.generator(z)
90
+ # Adversarial loss
91
+ generator_loss = -torch.mean(self.discriminator(gen_imgs))
92
+
93
+ generator_loss.backward()
94
+ generator_optimizer.step()
95
+
96
+ print(
97
+ f'[Epoch {epoch}/{epochs}] [Batch {batches_done % len(data_loader)}/{len(data_loader)}] ' +
98
+ f'[D loss: {discriminator_loss.item()}] [G loss: {generator_loss.item()}]'
99
+ )
100
+
101
+ if batches_done % sample_interval == 0:
102
+ save_image(gen_imgs.data[:25], f'{sample_folder}/{batches_done}.png', nrow=5, normalize=True)
103
+ batches_done += 1
104
+ self.discriminator.save(discriminator_save_file)
105
+ self.generator.save(generator_save_file)
106
+
107
+ def generate(
108
+ self,
109
+ sample_folder: str = 'samples'
110
+ ):
111
+ if not exists(sample_folder):
112
+ mkdir(sample_folder)
113
+
114
+ Tensor = torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor
115
+ z = Variable(Tensor(np.random.normal(0, 1, (self.image_shape[0], self.latent_space_dimension))))
116
+ gen_imgs = self.generator(z)
117
+ generator_loss = -torch.mean(self.discriminator(gen_imgs))
118
+ generator_loss.backward()
119
+ save_image(gen_imgs.data[:25], f'{sample_folder}/generated.png', nrow=5, normalize=True)