Spaces:
Runtime error
Runtime error
Muhammad Rama Nurimani
commited on
Commit
•
82449ec
1
Parent(s):
07e5fb9
test deploy
Browse files- __pycache__/colorization_model.cpython-311.pyc +0 -0
- app.py +62 -0
- base_model.py +230 -0
- colorization_model.py +68 -0
- pix2pix_model.py +127 -0
- requirements.txt +5 -0
- test_model.py +69 -0
__pycache__/colorization_model.cpython-311.pyc
ADDED
Binary file (4.46 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from torchvision import transforms
|
4 |
+
from PIL import Image
|
5 |
+
from colorization_model import ColorizationModel # Import your model class
|
6 |
+
|
7 |
+
# Load the trained generator model
|
8 |
+
model_path = "generator.pth"
|
9 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
10 |
+
|
11 |
+
# Define model options (replace with your configuration)
|
12 |
+
class Options:
|
13 |
+
input_nc = 1
|
14 |
+
output_nc = 2
|
15 |
+
ngf = 64
|
16 |
+
netG = "unet_256"
|
17 |
+
norm = "batch"
|
18 |
+
no_dropout = False
|
19 |
+
init_type = "normal"
|
20 |
+
init_gain = 0.02
|
21 |
+
gpu_ids = [0] if torch.cuda.is_available() else []
|
22 |
+
|
23 |
+
opt = Options()
|
24 |
+
generator = ColorizationModel(opt).netG
|
25 |
+
generator.load_state_dict(torch.load(model_path, map_location=device))
|
26 |
+
generator.eval().to(device)
|
27 |
+
|
28 |
+
# Define preprocessing and postprocessing steps
|
29 |
+
def preprocess_image(image):
|
30 |
+
transform = transforms.Compose([
|
31 |
+
transforms.Grayscale(num_output_channels=1),
|
32 |
+
transforms.Resize((256, 256)),
|
33 |
+
transforms.ToTensor(),
|
34 |
+
transforms.Normalize(mean=[0.5], std=[0.5])
|
35 |
+
])
|
36 |
+
return transform(image).unsqueeze(0).to(device)
|
37 |
+
|
38 |
+
def postprocess_image(output):
|
39 |
+
output = output.squeeze(0).cpu().detach()
|
40 |
+
output = torch.cat([output[0:1, :, :] * 50.0 + 50.0, output[1:, :, :] * 110.0], dim=0)
|
41 |
+
output_image = transforms.ToPILImage()(output)
|
42 |
+
return output_image
|
43 |
+
|
44 |
+
# Gradio interface function
|
45 |
+
def colorize(grayscale_image):
|
46 |
+
input_tensor = preprocess_image(grayscale_image)
|
47 |
+
with torch.no_grad():
|
48 |
+
colorized = generator(input_tensor)
|
49 |
+
return postprocess_image(colorized)
|
50 |
+
|
51 |
+
# Define Gradio interface
|
52 |
+
interface = gr.Interface(
|
53 |
+
fn=colorize,
|
54 |
+
inputs=gr.Image(type="pil", label="Grayscale Image"),
|
55 |
+
outputs=gr.Image(type="pil", label="Colorized Image"),
|
56 |
+
title="Pix2Pix Image Colorization",
|
57 |
+
description="Upload a grayscale image, and the model will colorize it using Pix2Pix GAN."
|
58 |
+
)
|
59 |
+
|
60 |
+
# Launch the app
|
61 |
+
if __name__ == "__main__":
|
62 |
+
interface.launch()
|
base_model.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
from abc import ABC, abstractmethod
|
5 |
+
from . import networks
|
6 |
+
|
7 |
+
|
8 |
+
class BaseModel(ABC):
|
9 |
+
"""This class is an abstract base class (ABC) for models.
|
10 |
+
To create a subclass, you need to implement the following five functions:
|
11 |
+
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
12 |
+
-- <set_input>: unpack data from dataset and apply preprocessing.
|
13 |
+
-- <forward>: produce intermediate results.
|
14 |
+
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
|
15 |
+
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, opt):
|
19 |
+
"""Initialize the BaseModel class.
|
20 |
+
|
21 |
+
Parameters:
|
22 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
23 |
+
|
24 |
+
When creating your custom class, you need to implement your own initialization.
|
25 |
+
In this function, you should first call <BaseModel.__init__(self, opt)>
|
26 |
+
Then, you need to define four lists:
|
27 |
+
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
28 |
+
-- self.model_names (str list): define networks used in our training.
|
29 |
+
-- self.visual_names (str list): specify the images that you want to display and save.
|
30 |
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
31 |
+
"""
|
32 |
+
self.opt = opt
|
33 |
+
self.gpu_ids = opt.gpu_ids
|
34 |
+
self.isTrain = opt.isTrain
|
35 |
+
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
|
36 |
+
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
|
37 |
+
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
|
38 |
+
torch.backends.cudnn.benchmark = True
|
39 |
+
self.loss_names = []
|
40 |
+
self.model_names = []
|
41 |
+
self.visual_names = []
|
42 |
+
self.optimizers = []
|
43 |
+
self.image_paths = []
|
44 |
+
self.metric = 0 # used for learning rate policy 'plateau'
|
45 |
+
|
46 |
+
@staticmethod
|
47 |
+
def modify_commandline_options(parser, is_train):
|
48 |
+
"""Add new model-specific options, and rewrite default values for existing options.
|
49 |
+
|
50 |
+
Parameters:
|
51 |
+
parser -- original option parser
|
52 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
the modified parser.
|
56 |
+
"""
|
57 |
+
return parser
|
58 |
+
|
59 |
+
@abstractmethod
|
60 |
+
def set_input(self, input):
|
61 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
62 |
+
|
63 |
+
Parameters:
|
64 |
+
input (dict): includes the data itself and its metadata information.
|
65 |
+
"""
|
66 |
+
pass
|
67 |
+
|
68 |
+
@abstractmethod
|
69 |
+
def forward(self):
|
70 |
+
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
71 |
+
pass
|
72 |
+
|
73 |
+
@abstractmethod
|
74 |
+
def optimize_parameters(self):
|
75 |
+
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
|
76 |
+
pass
|
77 |
+
|
78 |
+
def setup(self, opt):
|
79 |
+
"""Load and print networks; create schedulers
|
80 |
+
|
81 |
+
Parameters:
|
82 |
+
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
83 |
+
"""
|
84 |
+
if self.isTrain:
|
85 |
+
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
|
86 |
+
if not self.isTrain or opt.continue_train:
|
87 |
+
load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
|
88 |
+
self.load_networks(load_suffix)
|
89 |
+
self.print_networks(opt.verbose)
|
90 |
+
|
91 |
+
def eval(self):
|
92 |
+
"""Make models eval mode during test time"""
|
93 |
+
for name in self.model_names:
|
94 |
+
if isinstance(name, str):
|
95 |
+
net = getattr(self, 'net' + name)
|
96 |
+
net.eval()
|
97 |
+
|
98 |
+
def test(self):
|
99 |
+
"""Forward function used in test time.
|
100 |
+
|
101 |
+
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
|
102 |
+
It also calls <compute_visuals> to produce additional visualization results
|
103 |
+
"""
|
104 |
+
with torch.no_grad():
|
105 |
+
self.forward()
|
106 |
+
self.compute_visuals()
|
107 |
+
|
108 |
+
def compute_visuals(self):
|
109 |
+
"""Calculate additional output images for visdom and HTML visualization"""
|
110 |
+
pass
|
111 |
+
|
112 |
+
def get_image_paths(self):
|
113 |
+
""" Return image paths that are used to load current data"""
|
114 |
+
return self.image_paths
|
115 |
+
|
116 |
+
def update_learning_rate(self):
|
117 |
+
"""Update learning rates for all the networks; called at the end of every epoch"""
|
118 |
+
old_lr = self.optimizers[0].param_groups[0]['lr']
|
119 |
+
for scheduler in self.schedulers:
|
120 |
+
if self.opt.lr_policy == 'plateau':
|
121 |
+
scheduler.step(self.metric)
|
122 |
+
else:
|
123 |
+
scheduler.step()
|
124 |
+
|
125 |
+
lr = self.optimizers[0].param_groups[0]['lr']
|
126 |
+
print('learning rate %.7f -> %.7f' % (old_lr, lr))
|
127 |
+
|
128 |
+
def get_current_visuals(self):
|
129 |
+
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
|
130 |
+
visual_ret = OrderedDict()
|
131 |
+
for name in self.visual_names:
|
132 |
+
if isinstance(name, str):
|
133 |
+
visual_ret[name] = getattr(self, name)
|
134 |
+
return visual_ret
|
135 |
+
|
136 |
+
def get_current_losses(self):
|
137 |
+
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
|
138 |
+
errors_ret = OrderedDict()
|
139 |
+
for name in self.loss_names:
|
140 |
+
if isinstance(name, str):
|
141 |
+
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
|
142 |
+
return errors_ret
|
143 |
+
|
144 |
+
def save_networks(self, epoch):
|
145 |
+
"""Save all the networks to the disk.
|
146 |
+
|
147 |
+
Parameters:
|
148 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
149 |
+
"""
|
150 |
+
for name in self.model_names:
|
151 |
+
if isinstance(name, str):
|
152 |
+
save_filename = '%s_net_%s.pth' % (epoch, name)
|
153 |
+
save_path = os.path.join(self.save_dir, save_filename)
|
154 |
+
net = getattr(self, 'net' + name)
|
155 |
+
|
156 |
+
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
|
157 |
+
torch.save(net.module.cpu().state_dict(), save_path)
|
158 |
+
net.cuda(self.gpu_ids[0])
|
159 |
+
else:
|
160 |
+
torch.save(net.cpu().state_dict(), save_path)
|
161 |
+
|
162 |
+
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
|
163 |
+
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
|
164 |
+
key = keys[i]
|
165 |
+
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
|
166 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
167 |
+
(key == 'running_mean' or key == 'running_var'):
|
168 |
+
if getattr(module, key) is None:
|
169 |
+
state_dict.pop('.'.join(keys))
|
170 |
+
if module.__class__.__name__.startswith('InstanceNorm') and \
|
171 |
+
(key == 'num_batches_tracked'):
|
172 |
+
state_dict.pop('.'.join(keys))
|
173 |
+
else:
|
174 |
+
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
|
175 |
+
|
176 |
+
def load_networks(self, epoch):
|
177 |
+
"""Load all the networks from the disk.
|
178 |
+
|
179 |
+
Parameters:
|
180 |
+
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
181 |
+
"""
|
182 |
+
for name in self.model_names:
|
183 |
+
if isinstance(name, str):
|
184 |
+
load_filename = '%s_net_%s.pth' % (epoch, name)
|
185 |
+
load_path = os.path.join(self.save_dir, load_filename)
|
186 |
+
net = getattr(self, 'net' + name)
|
187 |
+
if isinstance(net, torch.nn.DataParallel):
|
188 |
+
net = net.module
|
189 |
+
print('loading the model from %s' % load_path)
|
190 |
+
# if you are using PyTorch newer than 0.4 (e.g., built from
|
191 |
+
# GitHub source), you can remove str() on self.device
|
192 |
+
state_dict = torch.load(load_path, map_location=str(self.device))
|
193 |
+
if hasattr(state_dict, '_metadata'):
|
194 |
+
del state_dict._metadata
|
195 |
+
|
196 |
+
# patch InstanceNorm checkpoints prior to 0.4
|
197 |
+
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
198 |
+
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
|
199 |
+
net.load_state_dict(state_dict)
|
200 |
+
|
201 |
+
def print_networks(self, verbose):
|
202 |
+
"""Print the total number of parameters in the network and (if verbose) network architecture
|
203 |
+
|
204 |
+
Parameters:
|
205 |
+
verbose (bool) -- if verbose: print the network architecture
|
206 |
+
"""
|
207 |
+
print('---------- Networks initialized -------------')
|
208 |
+
for name in self.model_names:
|
209 |
+
if isinstance(name, str):
|
210 |
+
net = getattr(self, 'net' + name)
|
211 |
+
num_params = 0
|
212 |
+
for param in net.parameters():
|
213 |
+
num_params += param.numel()
|
214 |
+
if verbose:
|
215 |
+
print(net)
|
216 |
+
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
|
217 |
+
print('-----------------------------------------------')
|
218 |
+
|
219 |
+
def set_requires_grad(self, nets, requires_grad=False):
|
220 |
+
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
|
221 |
+
Parameters:
|
222 |
+
nets (network list) -- a list of networks
|
223 |
+
requires_grad (bool) -- whether the networks require gradients or not
|
224 |
+
"""
|
225 |
+
if not isinstance(nets, list):
|
226 |
+
nets = [nets]
|
227 |
+
for net in nets:
|
228 |
+
if net is not None:
|
229 |
+
for param in net.parameters():
|
230 |
+
param.requires_grad = requires_grad
|
colorization_model.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .pix2pix_model import Pix2PixModel
|
2 |
+
import torch
|
3 |
+
from skimage import color # used for lab2rgb
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class ColorizationModel(Pix2PixModel):
|
8 |
+
"""This is a subclass of Pix2PixModel for image colorization (black & white image -> colorful images).
|
9 |
+
|
10 |
+
The model training requires '-dataset_model colorization' dataset.
|
11 |
+
It trains a pix2pix model, mapping from L channel to ab channels in Lab color space.
|
12 |
+
By default, the colorization dataset will automatically set '--input_nc 1' and '--output_nc 2'.
|
13 |
+
"""
|
14 |
+
@staticmethod
|
15 |
+
def modify_commandline_options(parser, is_train=True):
|
16 |
+
"""Add new dataset-specific options, and rewrite default values for existing options.
|
17 |
+
|
18 |
+
Parameters:
|
19 |
+
parser -- original option parser
|
20 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
the modified parser.
|
24 |
+
|
25 |
+
By default, we use 'colorization' dataset for this model.
|
26 |
+
See the original pix2pix paper (https://arxiv.org/pdf/1611.07004.pdf) and colorization results (Figure 9 in the paper)
|
27 |
+
"""
|
28 |
+
Pix2PixModel.modify_commandline_options(parser, is_train)
|
29 |
+
parser.set_defaults(dataset_mode='colorization')
|
30 |
+
return parser
|
31 |
+
|
32 |
+
def __init__(self, opt):
|
33 |
+
"""Initialize the class.
|
34 |
+
|
35 |
+
Parameters:
|
36 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
37 |
+
|
38 |
+
For visualization, we set 'visual_names' as 'real_A' (input real image),
|
39 |
+
'real_B_rgb' (ground truth RGB image), and 'fake_B_rgb' (predicted RGB image)
|
40 |
+
We convert the Lab image 'real_B' (inherited from Pix2pixModel) to a RGB image 'real_B_rgb'.
|
41 |
+
we convert the Lab image 'fake_B' (inherited from Pix2pixModel) to a RGB image 'fake_B_rgb'.
|
42 |
+
"""
|
43 |
+
# reuse the pix2pix model
|
44 |
+
Pix2PixModel.__init__(self, opt)
|
45 |
+
# specify the images to be visualized.
|
46 |
+
self.visual_names = ['real_A', 'real_B_rgb', 'fake_B_rgb']
|
47 |
+
|
48 |
+
def lab2rgb(self, L, AB):
|
49 |
+
"""Convert an Lab tensor image to a RGB numpy output
|
50 |
+
Parameters:
|
51 |
+
L (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
|
52 |
+
AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array)
|
53 |
+
|
54 |
+
Returns:
|
55 |
+
rgb (RGB numpy image): rgb output images (range: [0, 255], numpy array)
|
56 |
+
"""
|
57 |
+
AB2 = AB * 110.0
|
58 |
+
L2 = (L + 1.0) * 50.0
|
59 |
+
Lab = torch.cat([L2, AB2], dim=1)
|
60 |
+
Lab = Lab[0].data.cpu().float().numpy()
|
61 |
+
Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
|
62 |
+
rgb = color.lab2rgb(Lab) * 255
|
63 |
+
return rgb
|
64 |
+
|
65 |
+
def compute_visuals(self):
|
66 |
+
"""Calculate additional output images for visdom and HTML visualization"""
|
67 |
+
self.real_B_rgb = self.lab2rgb(self.real_A, self.real_B)
|
68 |
+
self.fake_B_rgb = self.lab2rgb(self.real_A, self.fake_B)
|
pix2pix_model.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .base_model import BaseModel
|
3 |
+
from . import networks
|
4 |
+
|
5 |
+
|
6 |
+
class Pix2PixModel(BaseModel):
|
7 |
+
""" This class implements the pix2pix model, for learning a mapping from input images to output images given paired data.
|
8 |
+
|
9 |
+
The model training requires '--dataset_mode aligned' dataset.
|
10 |
+
By default, it uses a '--netG unet256' U-Net generator,
|
11 |
+
a '--netD basic' discriminator (PatchGAN),
|
12 |
+
and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper).
|
13 |
+
|
14 |
+
pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf
|
15 |
+
"""
|
16 |
+
@staticmethod
|
17 |
+
def modify_commandline_options(parser, is_train=True):
|
18 |
+
"""Add new dataset-specific options, and rewrite default values for existing options.
|
19 |
+
|
20 |
+
Parameters:
|
21 |
+
parser -- original option parser
|
22 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
the modified parser.
|
26 |
+
|
27 |
+
For pix2pix, we do not use image buffer
|
28 |
+
The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1
|
29 |
+
By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets.
|
30 |
+
"""
|
31 |
+
# changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/)
|
32 |
+
parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='aligned')
|
33 |
+
if is_train:
|
34 |
+
parser.set_defaults(pool_size=0, gan_mode='vanilla')
|
35 |
+
parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
|
36 |
+
|
37 |
+
return parser
|
38 |
+
|
39 |
+
def __init__(self, opt):
|
40 |
+
"""Initialize the pix2pix class.
|
41 |
+
|
42 |
+
Parameters:
|
43 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
44 |
+
"""
|
45 |
+
BaseModel.__init__(self, opt)
|
46 |
+
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
|
47 |
+
self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake']
|
48 |
+
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
|
49 |
+
self.visual_names = ['real_A', 'fake_B', 'real_B']
|
50 |
+
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
|
51 |
+
if self.isTrain:
|
52 |
+
self.model_names = ['G', 'D']
|
53 |
+
else: # during test time, only load G
|
54 |
+
self.model_names = ['G']
|
55 |
+
# define networks (both generator and discriminator)
|
56 |
+
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
|
57 |
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
58 |
+
|
59 |
+
if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
|
60 |
+
self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
61 |
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
62 |
+
|
63 |
+
if self.isTrain:
|
64 |
+
# define loss functions
|
65 |
+
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)
|
66 |
+
self.criterionL1 = torch.nn.L1Loss()
|
67 |
+
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
|
68 |
+
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
|
69 |
+
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
|
70 |
+
self.optimizers.append(self.optimizer_G)
|
71 |
+
self.optimizers.append(self.optimizer_D)
|
72 |
+
|
73 |
+
def set_input(self, input):
|
74 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
75 |
+
|
76 |
+
Parameters:
|
77 |
+
input (dict): include the data itself and its metadata information.
|
78 |
+
|
79 |
+
The option 'direction' can be used to swap images in domain A and domain B.
|
80 |
+
"""
|
81 |
+
AtoB = self.opt.direction == 'AtoB'
|
82 |
+
self.real_A = input['A' if AtoB else 'B'].to(self.device)
|
83 |
+
self.real_B = input['B' if AtoB else 'A'].to(self.device)
|
84 |
+
self.image_paths = input['A_paths' if AtoB else 'B_paths']
|
85 |
+
|
86 |
+
def forward(self):
|
87 |
+
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
88 |
+
self.fake_B = self.netG(self.real_A) # G(A)
|
89 |
+
|
90 |
+
def backward_D(self):
|
91 |
+
"""Calculate GAN loss for the discriminator"""
|
92 |
+
# Fake; stop backprop to the generator by detaching fake_B
|
93 |
+
fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator
|
94 |
+
pred_fake = self.netD(fake_AB.detach())
|
95 |
+
self.loss_D_fake = self.criterionGAN(pred_fake, False)
|
96 |
+
# Real
|
97 |
+
real_AB = torch.cat((self.real_A, self.real_B), 1)
|
98 |
+
pred_real = self.netD(real_AB)
|
99 |
+
self.loss_D_real = self.criterionGAN(pred_real, True)
|
100 |
+
# combine loss and calculate gradients
|
101 |
+
self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
|
102 |
+
self.loss_D.backward()
|
103 |
+
|
104 |
+
def backward_G(self):
|
105 |
+
"""Calculate GAN and L1 loss for the generator"""
|
106 |
+
# First, G(A) should fake the discriminator
|
107 |
+
fake_AB = torch.cat((self.real_A, self.fake_B), 1)
|
108 |
+
pred_fake = self.netD(fake_AB)
|
109 |
+
self.loss_G_GAN = self.criterionGAN(pred_fake, True)
|
110 |
+
# Second, G(A) = B
|
111 |
+
self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1
|
112 |
+
# combine loss and calculate gradients
|
113 |
+
self.loss_G = self.loss_G_GAN + self.loss_G_L1
|
114 |
+
self.loss_G.backward()
|
115 |
+
|
116 |
+
def optimize_parameters(self):
|
117 |
+
self.forward() # compute fake images: G(A)
|
118 |
+
# update D
|
119 |
+
self.set_requires_grad(self.netD, True) # enable backprop for D
|
120 |
+
self.optimizer_D.zero_grad() # set D's gradients to zero
|
121 |
+
self.backward_D() # calculate gradients for D
|
122 |
+
self.optimizer_D.step() # update D's weights
|
123 |
+
# update G
|
124 |
+
self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G
|
125 |
+
self.optimizer_G.zero_grad() # set G's gradients to zero
|
126 |
+
self.backward_G() # calculate graidents for G
|
127 |
+
self.optimizer_G.step() # update G's weights
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
gradio
|
4 |
+
numpy
|
5 |
+
scikit-image
|
test_model.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_model import BaseModel
|
2 |
+
from . import networks
|
3 |
+
|
4 |
+
|
5 |
+
class TestModel(BaseModel):
|
6 |
+
""" This TesteModel can be used to generate CycleGAN results for only one direction.
|
7 |
+
This model will automatically set '--dataset_mode single', which only loads the images from one collection.
|
8 |
+
|
9 |
+
See the test instruction for more details.
|
10 |
+
"""
|
11 |
+
@staticmethod
|
12 |
+
def modify_commandline_options(parser, is_train=True):
|
13 |
+
"""Add new dataset-specific options, and rewrite default values for existing options.
|
14 |
+
|
15 |
+
Parameters:
|
16 |
+
parser -- original option parser
|
17 |
+
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
the modified parser.
|
21 |
+
|
22 |
+
The model can only be used during test time. It requires '--dataset_mode single'.
|
23 |
+
You need to specify the network using the option '--model_suffix'.
|
24 |
+
"""
|
25 |
+
assert not is_train, 'TestModel cannot be used during training time'
|
26 |
+
parser.set_defaults(dataset_mode='single')
|
27 |
+
parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
|
28 |
+
|
29 |
+
return parser
|
30 |
+
|
31 |
+
def __init__(self, opt):
|
32 |
+
"""Initialize the pix2pix class.
|
33 |
+
|
34 |
+
Parameters:
|
35 |
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
36 |
+
"""
|
37 |
+
assert(not opt.isTrain)
|
38 |
+
BaseModel.__init__(self, opt)
|
39 |
+
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
|
40 |
+
self.loss_names = []
|
41 |
+
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
|
42 |
+
self.visual_names = ['real', 'fake']
|
43 |
+
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
|
44 |
+
self.model_names = ['G' + opt.model_suffix] # only generator is needed.
|
45 |
+
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
|
46 |
+
opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
47 |
+
|
48 |
+
# assigns the model to self.netG_[suffix] so that it can be loaded
|
49 |
+
# please see <BaseModel.load_networks>
|
50 |
+
setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self.
|
51 |
+
|
52 |
+
def set_input(self, input):
|
53 |
+
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
|
54 |
+
|
55 |
+
Parameters:
|
56 |
+
input: a dictionary that contains the data itself and its metadata information.
|
57 |
+
|
58 |
+
We need to use 'single_dataset' dataset mode. It only load images from one domain.
|
59 |
+
"""
|
60 |
+
self.real = input['A'].to(self.device)
|
61 |
+
self.image_paths = input['A_paths']
|
62 |
+
|
63 |
+
def forward(self):
|
64 |
+
"""Run forward pass."""
|
65 |
+
self.fake = self.netG(self.real) # G(real)
|
66 |
+
|
67 |
+
def optimize_parameters(self):
|
68 |
+
"""No optimization for test model."""
|
69 |
+
pass
|