import argparse
import os

import numpy as np
import torch
from skimage import color
from torch import nn, optim
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm

from models import ColorEncoder, ColorUNet
from discriminator import Discriminator
from data.data_loader import MultiResolutionDataset
from utils import tensor_lab2rgb
from distributed import (
    get_rank,
    synchronize,
)
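
# Discriminator-only training: ColorEncoder and ColorUNet (the generator)
# are loaded from a checkpoint and kept frozen, while a PatchGAN-style
# discriminator is fitted with an LSGAN (MSE) objective to tell real images
# from generated colorizations.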


def mkdirss(dirpath):
    os.makedirs(dirpath, exist_ok=True)


def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    return data.SequentialSampler(dataset)


def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


def sample_data(loader):
    # Loop over the DataLoader forever so the training loop can simply
    # call next() once per iteration.
    while True:
        for batch in loader:
            yield batch
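

# Lab color-space helpers. Throughout this file the L channel is stored
# centered at zero (L - 50, so roughly [-50, 50]) while ab keeps its native
# range of roughly [-110, 110]; Lab2RGB_out and uncenter_l undo that shift.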
def Lab2RGB_out(img_lab):
    img_lab = img_lab.detach().cpu()
    img_l = img_lab[:, :1, :, :]
    img_ab = img_lab[:, 1:, :, :]

    # Undo the centering of L, then convert the first image in the batch.
    img_l = img_l + 50
    pred_lab = torch.cat((img_l, img_ab), 1)[0, ...].numpy()

    out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1) * 255).astype("uint8")
    return out


def RGB2Lab(inputs):
    return color.rgb2lab(inputs)


def Normalize(inputs):
    # Center the L channel; leave ab untouched.
    l = inputs[:, :, 0:1]
    ab = inputs[:, :, 1:3]
    l = l - 50
    lab = np.concatenate((l, ab), 2)
    return lab.astype("float32")


def numpy2tensor(inputs):
    # HWC numpy array -> CHW tensor.
    out = torch.from_numpy(inputs.transpose(2, 0, 1))
    return out


def tensor2numpy(inputs):
    # First image of a batched CHW tensor -> HWC numpy array.
    out = inputs[0, ...].detach().cpu().numpy().transpose(1, 2, 0)
    return out


def preprocessing(inputs):
    img_lab = Normalize(RGB2Lab(inputs))
    img = np.array(inputs, "float32")
    img = numpy2tensor(img)
    img_lab = numpy2tensor(img_lab)
    return img.unsqueeze(0), img_lab.unsqueeze(0)


def uncenter_l(inputs):
    l = inputs[:, :1, :, :] + 50
    ab = inputs[:, 1:, :, :]
    return torch.cat((l, ab), 1)
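

# Each iteration draws a (target, reference, target-Lab) triple from the
# loader, produces a colorization with the frozen generator, and takes one
# LSGAN step on the discriminator using real vs. generated image pairs.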
def train(
    args,
    loader,
    colorEncoder,
    colorUNet,
    discriminator,
    d_optim,
    device,
):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    disc_val_all = 0
    # LSGAN: the adversarial loss is plain MSE against patch-wise labels.
    criterion_GAN = torch.nn.MSELoss().to(device)
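
    # The discriminator is assumed to downsample by a factor of 2**4, so it
    # emits one real/fake score per 16x16 patch of the input.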
    patch = (1, args.size // 2 ** 4, args.size // 2 ** 4)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        img, img_ref, img_lab = next(loader)

        # Patch-wise LSGAN targets: 1 for real, 0 for fake.
        valid = torch.ones((img.size(0), *patch), device=device)
        fake = torch.zeros((img.size(0), *patch), device=device)

        img = img.to(device)
        img_lab = img_lab.to(device)
        img_ref = img_ref.to(device)

        # Scale L and ab to roughly [-1, 1] for the generator.
        img_l = img_lab[:, :1, :, :] / 50
        img_ab = img_lab[:, 1:, :, :] / 110

        # Only the discriminator is updated; the generator stays frozen.
        colorEncoder.eval()
        colorUNet.eval()
        discriminator.train()

        requires_grad(colorEncoder, False)
        requires_grad(colorUNet, False)
        requires_grad(discriminator, True)

        with torch.no_grad():
            ref_color_vector = colorEncoder(img_ref / 255.)
            fake_swap_ab = colorUNet((img_l, ref_color_vector))

        fake_swap_rgb = tensor_lab2rgb(torch.cat((img_l * 50 + 50, fake_swap_ab * 110), 1))
        real_img_rgb = img / 255.
        img_ref_rgb = img_ref / 255.

        # Grayscale conditioning image: the input L channel with ab zeroed out.
        zero_ab_image = torch.zeros_like(fake_swap_ab)
        input_img_rgb = tensor_lab2rgb(torch.cat((img_l * 50 + 50, zero_ab_image), 1))

        # Real pair vs. generated pair, each conditioned on the grayscale
        # input and the color reference.
        pred_real = discriminator(real_img_rgb, input_img_rgb, img_ref_rgb)
        loss_real = criterion_GAN(pred_real, valid)

        pred_fake = discriminator(fake_swap_rgb.detach(), input_img_rgb, img_ref_rgb)
        loss_fake = criterion_GAN(pred_fake, fake)

        disc_loss = 0.5 * (loss_real + loss_fake)

        d_optim.zero_grad()
        disc_loss.backward()
        d_optim.step()

        disc_val = disc_loss.mean().item()
        disc_val_all += disc_val

        if get_rank() == 0:
            pbar.set_description(f"discriminator:{disc_val:.4f};")

            if i % 100 == 0:
                print(f"discriminator:{disc_val_all / 100:.4f};")
                disc_val_all = 0
            if i % 1000 == 0:
                out_dir = f"experiments/{args.experiment_name}"
                mkdirss(out_dir)
                # Unwrap DistributedDataParallel (if used) so the checkpoint
                # can be loaded into a bare Discriminator.
                torch.save(
                    {
                        "discriminator": getattr(discriminator, "module", discriminator).state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                    },
                    f"{out_dir}/{str(i).zfill(6)}_ds.pt",
                )
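

# Typical launches (the script file name "train_disc.py" is an assumption):
#   single GPU: python train_disc.py --datasets <lmdb_dir> --ckpt <generator.pt>
#   multi GPU:  python -m torch.distributed.launch --nproc_per_node=N \
#                   train_disc.py --datasets <lmdb_dir> --ckpt <generator.pt>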
if __name__ == "__main__":
    device = "cuda"

    torch.backends.cudnn.benchmark = True

    parser = argparse.ArgumentParser()

    parser.add_argument("--datasets", type=str)
    parser.add_argument("--iter", type=int, default=100000)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--ckpt_disc", type=str, default=None)
    parser.add_argument("--lr", type=float, default=0.0002)
    parser.add_argument("--experiment_name", type=str, default="default")
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--local_rank", type=int, default=0)

    args = parser.parse_args()

    # torch.distributed.launch / torchrun set WORLD_SIZE for multi-GPU runs.
    n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    args.start_iter = 0

    colorEncoder = ColorEncoder(color_dim=512).to(device)
    colorUNet = ColorUNet(bilinear=True).to(device)
    discriminator = Discriminator(in_channels=3).to(device)

    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr,
        betas=(0.5, 0.999),
    )

    if args.ckpt is not None:
        print("load model:", args.ckpt)

        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)

        colorEncoder.load_state_dict(ckpt["colorEncoder"])
        colorUNet.load_state_dict(ckpt["colorUNet"])

    if args.ckpt_disc is not None:
        print("load discriminator model:", args.ckpt_disc)

        ckpt_disc = torch.load(args.ckpt_disc, map_location=lambda storage, loc: storage)

        try:
            # Checkpoints are named "<iteration>_ds.pt", so take the digits
            # before the underscore to resume the iteration counter.
            ckpt_name = os.path.basename(args.ckpt_disc)
            args.start_iter = int(os.path.splitext(ckpt_name)[0].split("_")[0])
        except ValueError:
            pass

        discriminator.load_state_dict(ckpt_disc["discriminator"])
        d_optim.load_state_dict(ckpt_disc["d_optim"])

    if args.distributed:
        colorEncoder = nn.parallel.DistributedDataParallel(
            colorEncoder,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        colorUNet = nn.parallel.DistributedDataParallel(
            colorUNet,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
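
        # The discriminator is the network actually being optimized, so wrap
        # it too: without DDP its gradients would not be synchronized across
        # ranks. The checkpoint code above saves/loads the unwrapped module.
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )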

    # Geometric augmentations only; no color jitter.
    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(degrees=(0, 360)),
        ]
    )

    datasets = []
    dataset = MultiResolutionDataset(args.datasets, transform, args.size)
    datasets.append(dataset)

    loader = data.DataLoader(
        data.ConcatDataset(datasets),
        batch_size=args.batch,
        sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    train(
        args,
        loader,
        colorEncoder,
        colorUNet,
        discriminator,
        d_optim,
        device,
    )