Muhammad Rama Nurimani committed
Commit 82449ec
1 parent: 07e5fb9

test deploy

__pycache__/colorization_model.cpython-311.pyc ADDED
Binary file (4.46 kB).
 
app.py ADDED
@@ -0,0 +1,75 @@
+ import gradio as gr
+ import torch
+ import numpy as np
+ from torchvision import transforms
+ from PIL import Image
+ from skimage import color  # lab2rgb, as in colorization_model.py
+ from colorization_model import ColorizationModel  # Import your model class
+
+ # Load the trained generator model
+ model_path = "generator.pth"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Define model options (replace with your configuration).
+ # BaseModel.__init__ also reads isTrain, checkpoints_dir, name, and preprocess,
+ # so they must be present even though this app only runs inference.
+ class Options:
+     input_nc = 1
+     output_nc = 2
+     ngf = 64
+     netG = "unet_256"
+     norm = "batch"
+     no_dropout = False
+     init_type = "normal"
+     init_gain = 0.02
+     gpu_ids = [0] if torch.cuda.is_available() else []
+     isTrain = False  # inference only: no discriminator or optimizers
+     checkpoints_dir = "./checkpoints"
+     name = "colorization"
+     preprocess = "resize"
+
+ opt = Options()
+ generator = ColorizationModel(opt).netG
+ if isinstance(generator, torch.nn.DataParallel):  # unwrap, mirroring BaseModel.load_networks
+     generator = generator.module
+ generator.load_state_dict(torch.load(model_path, map_location=device))
+ generator.eval().to(device)
+
+ # Define preprocessing and postprocessing steps
+ def preprocess_image(image):
+     transform = transforms.Compose([
+         transforms.Grayscale(num_output_channels=1),
+         transforms.Resize((256, 256)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.5], std=[0.5])  # L channel -> [-1, 1]
+     ])
+     return transform(image).unsqueeze(0).to(device)
+
+ def postprocess_image(input_L, output_ab):
+     # The generator outputs two channels (ab); combine them with the input L
+     # channel and convert Lab -> RGB, mirroring ColorizationModel.lab2rgb.
+     L = (input_L.squeeze(0).cpu() + 1.0) * 50.0       # [-1, 1] -> [0, 100]
+     ab = output_ab.squeeze(0).detach().cpu() * 110.0  # [-1, 1] -> [-110, 110]
+     lab = torch.cat([L, ab], dim=0).numpy().astype(np.float64)
+     rgb = color.lab2rgb(np.transpose(lab, (1, 2, 0))) * 255
+     return Image.fromarray(rgb.astype(np.uint8))
+
+ # Gradio interface function
+ def colorize(grayscale_image):
+     input_tensor = preprocess_image(grayscale_image)
+     with torch.no_grad():
+         colorized = generator(input_tensor)
+     return postprocess_image(input_tensor, colorized)
+
+ # Define the Gradio interface
+ interface = gr.Interface(
+     fn=colorize,
+     inputs=gr.Image(type="pil", label="Grayscale Image"),
+     outputs=gr.Image(type="pil", label="Colorized Image"),
+     title="Pix2Pix Image Colorization",
+     description="Upload a grayscale image, and the model will colorize it using a Pix2Pix GAN."
+ )
+
+ # Launch the app
+ if __name__ == "__main__":
+     interface.launch()
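
A quick local smoke test for app.py before pushing to the Space: a minimal sketch assuming generator.pth is present and that "sample.jpg" (a placeholder name) is any photo on disk. Importing app builds the model and loads the weights; interface.launch() stays guarded behind __main__.

from PIL import Image
from app import colorize  # importing app loads generator.pth as a side effect

# "sample.jpg" is a hypothetical file; any image works, since preprocess_image
# converts it to a single-channel 256x256 L tensor in [-1, 1].
result = colorize(Image.open("sample.jpg"))
result.save("sample_colorized.png")
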
base_model.py ADDED
@@ -0,0 +1,230 @@
+ import os
+ import torch
+ from collections import OrderedDict
+ from abc import ABC, abstractmethod
+ import networks  # flat import; this commit places the modules at the repo root
+
+
+ class BaseModel(ABC):
+     """This class is an abstract base class (ABC) for models.
+     To create a subclass, you need to implement the following five functions:
+         -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
+         -- <set_input>: unpack data from the dataset and apply preprocessing.
+         -- <forward>: produce intermediate results.
+         -- <optimize_parameters>: calculate losses, gradients, and update network weights.
+         -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
+     """
+
+     def __init__(self, opt):
+         """Initialize the BaseModel class.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+         When creating your custom class, you need to implement your own initialization.
+         In this function, you should first call <BaseModel.__init__(self, opt)>.
+         Then, you need to define four lists:
+             -- self.loss_names (str list): specify the training losses that you want to plot and save.
+             -- self.model_names (str list): define the networks used in training.
+             -- self.visual_names (str list): specify the images that you want to display and save.
+             -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+         """
+         self.opt = opt
+         self.gpu_ids = opt.gpu_ids
+         self.isTrain = opt.isTrain
+         self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
+         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
+         if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
+             torch.backends.cudnn.benchmark = True
+         self.loss_names = []
+         self.model_names = []
+         self.visual_names = []
+         self.optimizers = []
+         self.image_paths = []
+         self.metric = 0  # used for learning rate policy 'plateau'
+
+     @staticmethod
+     def modify_commandline_options(parser, is_train):
+         """Add new model-specific options, and rewrite default values for existing options.
+
+         Parameters:
+             parser -- original option parser
+             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+         Returns:
+             the modified parser.
+         """
+         return parser
+
+     @abstractmethod
+     def set_input(self, input):
+         """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+         Parameters:
+             input (dict): includes the data itself and its metadata information.
+         """
+         pass
+
+     @abstractmethod
+     def forward(self):
+         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+         pass
+
+     @abstractmethod
+     def optimize_parameters(self):
+         """Calculate losses, gradients, and update network weights; called in every training iteration."""
+         pass
+
+     def setup(self, opt):
+         """Load and print networks; create schedulers.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         if self.isTrain:
+             self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+         if not self.isTrain or opt.continue_train:
+             load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
+             self.load_networks(load_suffix)
+         self.print_networks(opt.verbose)
+
+     def eval(self):
+         """Put the models in eval mode during test time"""
+         for name in self.model_names:
+             if isinstance(name, str):
+                 net = getattr(self, 'net' + name)
+                 net.eval()
+
+     def test(self):
+         """Forward function used at test time.
+
+         This function wraps the <forward> function in no_grad() so we don't save intermediate steps for backprop.
+         It also calls <compute_visuals> to produce additional visualization results.
+         """
+         with torch.no_grad():
+             self.forward()
+             self.compute_visuals()
+
+     def compute_visuals(self):
+         """Calculate additional output images for visdom and HTML visualization"""
+         pass
+
+     def get_image_paths(self):
+         """Return image paths that are used to load the current data"""
+         return self.image_paths
+
+     def update_learning_rate(self):
+         """Update learning rates for all the networks; called at the end of every epoch"""
+         old_lr = self.optimizers[0].param_groups[0]['lr']
+         for scheduler in self.schedulers:
+             if self.opt.lr_policy == 'plateau':
+                 scheduler.step(self.metric)
+             else:
+                 scheduler.step()
+
+         lr = self.optimizers[0].param_groups[0]['lr']
+         print('learning rate %.7f -> %.7f' % (old_lr, lr))
+
+     def get_current_visuals(self):
+         """Return visualization images. train.py will display these images with visdom, and save them to an HTML file"""
+         visual_ret = OrderedDict()
+         for name in self.visual_names:
+             if isinstance(name, str):
+                 visual_ret[name] = getattr(self, name)
+         return visual_ret
+
+     def get_current_losses(self):
+         """Return training losses / errors. train.py will print these errors on the console, and save them to a file"""
+         errors_ret = OrderedDict()
+         for name in self.loss_names:
+             if isinstance(name, str):
+                 errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensors and float numbers
+         return errors_ret
+
+     def save_networks(self, epoch):
+         """Save all the networks to the disk.
+
+         Parameters:
+             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+         """
+         for name in self.model_names:
+             if isinstance(name, str):
+                 save_filename = '%s_net_%s.pth' % (epoch, name)
+                 save_path = os.path.join(self.save_dir, save_filename)
+                 net = getattr(self, 'net' + name)
+
+                 if len(self.gpu_ids) > 0 and torch.cuda.is_available():
+                     torch.save(net.module.cpu().state_dict(), save_path)
+                     net.cuda(self.gpu_ids[0])
+                 else:
+                     torch.save(net.cpu().state_dict(), save_path)
+
+     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+         """Fix InstanceNorm checkpoint incompatibility (prior to PyTorch 0.4)"""
+         key = keys[i]
+         if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+             if module.__class__.__name__.startswith('InstanceNorm') and \
+                     (key == 'running_mean' or key == 'running_var'):
+                 if getattr(module, key) is None:
+                     state_dict.pop('.'.join(keys))
+             if module.__class__.__name__.startswith('InstanceNorm') and \
+                     (key == 'num_batches_tracked'):
+                 state_dict.pop('.'.join(keys))
+         else:
+             self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+     def load_networks(self, epoch):
+         """Load all the networks from the disk.
+
+         Parameters:
+             epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+         """
+         for name in self.model_names:
+             if isinstance(name, str):
+                 load_filename = '%s_net_%s.pth' % (epoch, name)
+                 load_path = os.path.join(self.save_dir, load_filename)
+                 net = getattr(self, 'net' + name)
+                 if isinstance(net, torch.nn.DataParallel):
+                     net = net.module
+                 print('loading the model from %s' % load_path)
+                 # if you are using PyTorch newer than 0.4 (e.g., built from
+                 # GitHub source), you can remove str() on self.device
+                 state_dict = torch.load(load_path, map_location=str(self.device))
+                 if hasattr(state_dict, '_metadata'):
+                     del state_dict._metadata
+
+                 # patch InstanceNorm checkpoints prior to 0.4
+                 for key in list(state_dict.keys()):  # need to copy the keys here because we mutate in the loop
+                     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                 net.load_state_dict(state_dict)
+
+     def print_networks(self, verbose):
+         """Print the total number of parameters in the network and (if verbose) the network architecture.
+
+         Parameters:
+             verbose (bool) -- if verbose: print the network architecture
+         """
+         print('---------- Networks initialized -------------')
+         for name in self.model_names:
+             if isinstance(name, str):
+                 net = getattr(self, 'net' + name)
+                 num_params = 0
+                 for param in net.parameters():
+                     num_params += param.numel()
+                 if verbose:
+                     print(net)
+                 print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+         print('-----------------------------------------------')
+
+     def set_requires_grad(self, nets, requires_grad=False):
+         """Set requires_grad=False for all the networks to avoid unnecessary computations.
+         Parameters:
+             nets (network list)   -- a list of networks
+             requires_grad (bool)  -- whether the networks require gradients or not
+         """
+         if not isinstance(nets, list):
+             nets = [nets]
+         for net in nets:
+             if net is not None:
+                 for param in net.parameters():
+                     param.requires_grad = requires_grad
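
The five-function contract in the BaseModel docstring is easiest to see in a minimal subclass. The sketch below is a hypothetical toy model (not part of this commit) whose forward pass simply copies its input, leaving only the required plumbing; opt is assumed to come from the framework's option parser.

from base_model import BaseModel

class IdentityModel(BaseModel):
    """Toy subclass: the 'prediction' is just the input image."""
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = []                  # nothing to plot or log
        self.model_names = []                 # nothing to save or load
        self.visual_names = ['real', 'fake']  # returned by get_current_visuals

    def set_input(self, input):
        self.real = input['A'].to(self.device)
        self.image_paths = input['A_paths']

    def forward(self):
        self.fake = self.real  # identity mapping

    def optimize_parameters(self):
        self.forward()  # no losses, no optimizer steps
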
colorization_model.py ADDED
@@ -0,0 +1,68 @@
+ from pix2pix_model import Pix2PixModel  # flat import; this commit places the modules at the repo root
+ import torch
+ from skimage import color  # used for lab2rgb
+ import numpy as np
+
+
+ class ColorizationModel(Pix2PixModel):
+     """This is a subclass of Pix2PixModel for image colorization (black & white images -> colorful images).
+
+     The model training requires the '--dataset_mode colorization' dataset.
+     It trains a pix2pix model, mapping from the L channel to the ab channels in Lab color space.
+     By default, the colorization dataset will automatically set '--input_nc 1' and '--output_nc 2'.
+     """
+     @staticmethod
+     def modify_commandline_options(parser, is_train=True):
+         """Add new dataset-specific options, and rewrite default values for existing options.
+
+         Parameters:
+             parser -- original option parser
+             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+         Returns:
+             the modified parser.
+
+         By default, we use the 'colorization' dataset for this model.
+         See the original pix2pix paper (https://arxiv.org/pdf/1611.07004.pdf) and its colorization results (Figure 9 in the paper).
+         """
+         Pix2PixModel.modify_commandline_options(parser, is_train)
+         parser.set_defaults(dataset_mode='colorization')
+         return parser
+
+     def __init__(self, opt):
+         """Initialize the class.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+         For visualization, we set 'visual_names' as 'real_A' (input real image),
+         'real_B_rgb' (ground truth RGB image), and 'fake_B_rgb' (predicted RGB image).
+         We convert the Lab image 'real_B' (inherited from Pix2PixModel) to an RGB image 'real_B_rgb'.
+         We convert the Lab image 'fake_B' (inherited from Pix2PixModel) to an RGB image 'fake_B_rgb'.
+         """
+         # reuse the pix2pix model
+         Pix2PixModel.__init__(self, opt)
+         # specify the images to be visualized.
+         self.visual_names = ['real_A', 'real_B_rgb', 'fake_B_rgb']
+
+     def lab2rgb(self, L, AB):
+         """Convert a Lab tensor image to an RGB numpy output.
+         Parameters:
+             L  (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
+             AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array)
+
+         Returns:
+             rgb (RGB numpy image): rgb output images (range: [0, 255], numpy array)
+         """
+         AB2 = AB * 110.0
+         L2 = (L + 1.0) * 50.0
+         Lab = torch.cat([L2, AB2], dim=1)
+         Lab = Lab[0].data.cpu().float().numpy()
+         Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
+         rgb = color.lab2rgb(Lab) * 255
+         return rgb
+
+     def compute_visuals(self):
+         """Calculate additional output images for visdom and HTML visualization"""
+         self.real_B_rgb = self.lab2rgb(self.real_A, self.real_B)
+         self.fake_B_rgb = self.lab2rgb(self.real_A, self.fake_B)
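
For reference, the [-1, 1] ranges that lab2rgb assumes come from the normalization the colorization dataset applies on the way in. A small sketch of that round-trip on dummy data (not part of the commit):

import numpy as np
import torch
from skimage import color

rgb = np.random.rand(8, 8, 3)      # dummy RGB image with values in [0, 1]
lab = color.rgb2lab(rgb)           # L in [0, 100], ab roughly in [-110, 110]
lab_t = torch.from_numpy(lab).permute(2, 0, 1).unsqueeze(0)  # (1, 3, H, W)
L = lab_t[:, 0:1] / 50.0 - 1.0     # [0, 100]    -> [-1, 1]
AB = lab_t[:, 1:] / 110.0          # [-110, 110] -> [-1, 1]
# L and AB now have the ranges lab2rgb expects, so
# ColorizationModel.lab2rgb(L, AB) recovers approximately rgb * 255.
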
pix2pix_model.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ from base_model import BaseModel  # flat import; this commit places the modules at the repo root
+ import networks
+
+
+ class Pix2PixModel(BaseModel):
+     """This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
+
+     The model training requires the '--dataset_mode aligned' dataset.
+     By default, it uses a '--netG unet_256' U-Net generator,
+     a '--netD basic' discriminator (PatchGAN),
+     and a '--gan_mode vanilla' GAN loss (the cross-entropy objective used in the original GAN paper).
+
+     pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
+     """
+     @staticmethod
+     def modify_commandline_options(parser, is_train=True):
+         """Add new dataset-specific options, and rewrite default values for existing options.
+
+         Parameters:
+             parser -- original option parser
+             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+         Returns:
+             the modified parser.
+
+         For pix2pix, we do not use an image buffer.
+         The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
+         By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
+         """
+         # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
+         parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
+         if is_train:
+             parser.set_defaults(pool_size=0, gan_mode='vanilla')
+             parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
+
+         return parser
+
+     def __init__(self, opt):
+         """Initialize the pix2pix class.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         BaseModel.__init__(self, opt)
+         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+         self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
+         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
+         self.visual_names = ['real_A', 'fake_B', 'real_B']
+         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
+         if self.isTrain:
+             self.model_names = ['G', 'D']
+         else:  # during test time, only load G
+             self.model_names = ['G']
+         # define networks (both generator and discriminator)
+         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
+                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+
+         if self.isTrain:  # define a discriminator; conditional GANs need to take both input and output images; therefore, #channels for D is input_nc + output_nc
+             self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+
+         if self.isTrain:
+             # define loss functions
+             self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
+             self.criterionL1 = torch.nn.L1Loss()
+             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
+             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+             self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+             self.optimizers.append(self.optimizer_G)
+             self.optimizers.append(self.optimizer_D)
+
+     def set_input(self, input):
+         """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+         Parameters:
+             input (dict): includes the data itself and its metadata information.
+
+         The option 'direction' can be used to swap images in domain A and domain B.
+         """
+         AtoB = self.opt.direction == 'AtoB'
+         self.real_A = input['A' if AtoB else 'B'].to(self.device)
+         self.real_B = input['B' if AtoB else 'A'].to(self.device)
+         self.image_paths = input['A_paths' if AtoB else 'B_paths']
+
+     def forward(self):
+         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+         self.fake_B = self.netG(self.real_A)  # G(A)
+
+     def backward_D(self):
+         """Calculate GAN loss for the discriminator"""
+         # Fake; stop backprop to the generator by detaching fake_B
+         fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
+         pred_fake = self.netD(fake_AB.detach())
+         self.loss_D_fake = self.criterionGAN(pred_fake, False)
+         # Real
+         real_AB = torch.cat((self.real_A, self.real_B), 1)
+         pred_real = self.netD(real_AB)
+         self.loss_D_real = self.criterionGAN(pred_real, True)
+         # combine losses and calculate gradients
+         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
+         self.loss_D.backward()
+
+     def backward_G(self):
+         """Calculate GAN and L1 loss for the generator"""
+         # First, G(A) should fool the discriminator
+         fake_AB = torch.cat((self.real_A, self.fake_B), 1)
+         pred_fake = self.netD(fake_AB)
+         self.loss_G_GAN = self.criterionGAN(pred_fake, True)
+         # Second, G(A) = B
+         self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
+         # combine losses and calculate gradients
+         self.loss_G = self.loss_G_GAN + self.loss_G_L1
+         self.loss_G.backward()
+
+     def optimize_parameters(self):
+         self.forward()                   # compute fake images: G(A)
+         # update D
+         self.set_requires_grad(self.netD, True)   # enable backprop for D
+         self.optimizer_D.zero_grad()     # set D's gradients to zero
+         self.backward_D()                # calculate gradients for D
+         self.optimizer_D.step()          # update D's weights
+         # update G
+         self.set_requires_grad(self.netD, False)  # D requires no gradients when optimizing G
+         self.optimizer_G.zero_grad()     # set G's gradients to zero
+         self.backward_G()                # calculate gradients for G
+         self.optimizer_G.step()          # update G's weights
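
optimize_parameters encodes the usual pix2pix alternation: one discriminator step on (real_A, fake_B) versus (real_A, real_B) pairs, then one generator step on the GAN + lambda_L1 * L1 objective. A hedged sketch of the loop a training script would run around it; opt, dataset, and n_epochs are assumed to come from the surrounding framework:

model = Pix2PixModel(opt)
model.setup(opt)                     # create schedulers, optionally resume
for epoch in range(n_epochs):
    for data in dataset:             # dicts with 'A', 'B', 'A_paths', 'B_paths'
        model.set_input(data)        # unpack and move tensors to the device
        model.optimize_parameters()  # one D step, then one G step
    model.update_learning_rate()     # step the schedulers once per epoch
    model.save_networks(epoch)       # writes <epoch>_net_G.pth and <epoch>_net_D.pth
print(model.get_current_losses())    # {'G_GAN': ..., 'G_L1': ..., 'D_real': ..., 'D_fake': ...}
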
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ torchvision
+ gradio
+ numpy
+ scikit-image
test_model.py ADDED
@@ -0,0 +1,69 @@
+ from base_model import BaseModel  # flat import; this commit places the modules at the repo root
+ import networks
+
+
+ class TestModel(BaseModel):
+     """This TestModel can be used to generate CycleGAN results for only one direction.
+     This model will automatically set '--dataset_mode single', which only loads images from one collection.
+
+     See the test instructions for more details.
+     """
+     @staticmethod
+     def modify_commandline_options(parser, is_train=True):
+         """Add new dataset-specific options, and rewrite default values for existing options.
+
+         Parameters:
+             parser -- original option parser
+             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+         Returns:
+             the modified parser.
+
+         The model can only be used during test time. It requires '--dataset_mode single'.
+         You need to specify the network using the option '--model_suffix'.
+         """
+         assert not is_train, 'TestModel cannot be used during training time'
+         parser.set_defaults(dataset_mode='single')
+         parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
+
+         return parser
+
+     def __init__(self, opt):
+         """Initialize the TestModel class.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         assert (not opt.isTrain)
+         BaseModel.__init__(self, opt)
+         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+         self.loss_names = []
+         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
+         self.visual_names = ['real', 'fake']
+         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
+         self.model_names = ['G' + opt.model_suffix]  # only the generator is needed.
+         self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
+                                       opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+
+         # assigns the model to self.netG[model_suffix] so that it can be loaded;
+         # please see <BaseModel.load_networks>
+         setattr(self, 'netG' + opt.model_suffix, self.netG)  # store netG in self.
+
+     def set_input(self, input):
+         """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+         Parameters:
+             input: a dictionary that contains the data itself and its metadata information.
+
+         We need to use the 'single_dataset' dataset mode. It only loads images from one domain.
+         """
+         self.real = input['A'].to(self.device)
+         self.image_paths = input['A_paths']
+
+     def forward(self):
+         """Run forward pass."""
+         self.fake = self.netG(self.real)  # G(real)
+
+     def optimize_parameters(self):
+         """No optimization for the test model."""
+         pass
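
For completeness, a hedged sketch of single-direction inference with TestModel; opt is assumed to come from the framework's test options (with '--dataset_mode single' and, if needed, '--model_suffix'):

model = TestModel(opt)
model.setup(opt)   # loads [epoch]_net_G[model_suffix].pth from checkpoints_dir
model.eval()
for data in dataset:                        # dicts with 'A' and 'A_paths'
    model.set_input(data)
    model.test()                            # forward() wrapped in torch.no_grad()
    visuals = model.get_current_visuals()   # {'real': ..., 'fake': ...}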