# Source: uploaded by Keiser41 ("Upload 98 files", commit 22d8ab7).
import argparse
import os
import numpy as np
from PIL import Image
from skimage import color, io
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm
from torch.autograd import Variable
# from ColorEncoder import ColorEncoder
from models import ColorEncoder, ColorUNet
from discriminator import Discriminator
from data.data_loader import MultiResolutionDataset
from utils import tensor_lab2rgb
from distributed import (
get_rank,
synchronize,
reduce_loss_dict,
)
def mkdirss(dirpath):
    """Create ``dirpath`` (including missing parents) if it does not exist.

    ``exist_ok=True`` makes the call idempotent. It also removes the window
    between "check if it exists" and "create it" in which another process
    could create the same directory and make ``makedirs`` raise.
    If ``dirpath`` exists but is a regular file, ``FileExistsError`` is raised.

    Args:
        dirpath: Directory path to create.
    """
    os.makedirs(dirpath, exist_ok=True)
def data_sampler(dataset, shuffle, distributed):
    """Choose the index sampler for ``dataset``.

    Distributed runs get a ``DistributedSampler`` that receives ``shuffle``.
    Otherwise ``shuffle`` selects between a random and a sequential sampler.
    """
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
    sampler_cls = data.RandomSampler if shuffle else data.SequentialSampler
    return sampler_cls(dataset)
def requires_grad(model, flag=True):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``model``."""
    params = model.parameters()
    for tensor in params:
        tensor.requires_grad = flag
def sample_data(loader):
    """Yield batches from ``loader`` endlessly, restarting it after each pass."""
    while True:
        yield from loader
def Lab2RGB_out(img_lab):
    """Convert batch element 0 of a Lab tensor to a uint8 RGB image array.

    ``img_lab`` is indexed as ``[batch, channel, H, W]``. Channel 0 holds L
    shifted down by 50, which is undone here; the remaining channels hold ab.
    Only the first batch element is converted, via ``skimage.color.lab2rgb``.
    The result is clipped to [0, 1], scaled to [0, 255] and returned as a
    channel-last uint8 array.
    """
    lab_cpu = img_lab.detach().cpu()
    restored = torch.cat((lab_cpu[:, :1, :, :] + 50, lab_cpu[:, 1:, :, :]), 1)
    first_hwc = restored[0, ...].numpy().transpose(1, 2, 0)
    rgb_unit = np.clip(color.lab2rgb(first_hwc), 0, 1)
    return (rgb_unit * 255).astype("uint8")
def RGB2Lab(inputs):
    """Convert an RGB image to CIE-Lab with ``skimage.color.rgb2lab``.

    The original author notes the input as uint8 in [0, 255] and the output
    as L in [0, 100], ab in [-110, 110]. The output is a floating-point array
    produced by skimage.
    """
    lab = color.rgb2lab(inputs)
    return lab
def Normalize(inputs):
    """Centre the L channel of a channel-last Lab array by subtracting 50.

    Channel 0 (L) has 50 subtracted, and channels 1-2 (ab) are passed through.
    Any channels past index 2 are dropped. The result is returned as float32.
    """
    centered_l = inputs[:, :, :1] - 50
    chroma = inputs[:, :, 1:3]
    return np.concatenate([centered_l, chroma], axis=2).astype('float32')
def numpy2tensor(inputs):
    """Wrap an ``(H, W, C)`` array as a ``(C, H, W)`` tensor.

    ``torch.from_numpy`` shares memory with the (transposed) array view.
    """
    chw_view = np.transpose(inputs, (2, 0, 1))
    return torch.from_numpy(chw_view)
def tensor2numpy(inputs):
    """Return batch element 0 of an ``(N, C, H, W)`` tensor as an ``(H, W, C)`` array."""
    first = inputs[0, ...]
    as_array = first.detach().cpu().numpy()
    return np.transpose(as_array, (1, 2, 0))
def preprocessing(inputs):
    """Build batched RGB and centred-Lab tensors from one RGB image.

    Args:
        inputs: RGB image, uint8 in [0, 255] per the original comment. It is
            assumed channel-last ``(H, W, 3)``, as ``Normalize`` and
            ``numpy2tensor`` index it that way.

    Returns:
        ``(rgb, lab)``, each with a leading batch dimension of 1 and channels
        first. ``rgb`` is float32 in the input's [0, 255] range; ``lab`` is
        float32 Lab with L centred by -50.
    """
    lab_tensor = numpy2tensor(Normalize(RGB2Lab(inputs))).unsqueeze(0)
    rgb_tensor = numpy2tensor(np.array(inputs, 'float32')).unsqueeze(0)
    return rgb_tensor, lab_tensor
def uncenter_l(inputs):
    """Undo the -50 L centring on an ``(N, C, H, W)`` Lab tensor; ab is unchanged."""
    restored_l = inputs[:, :1, :, :] + 50
    return torch.cat([restored_l, inputs[:, 1:, :, :]], dim=1)
def _build_disc_inputs(colorEncoder, colorUNet, img, img_ref, img_lab):
    """Run the frozen generator and return the discriminator's image inputs.

    ``img``/``img_ref`` are divided by 255 here, so presumably RGB in
    [0, 255]. ``img_lab`` is assumed to carry L centred by -50 in channel 0
    (the ``/ 50`` and ``* 50 + 50`` below rely on that) — TODO confirm against
    the dataset.

    Returns:
        ``(real_rgb, fake_rgb, gray_rgb, ref_rgb)``. ``fake_rgb`` and
        ``gray_rgb`` are computed under ``no_grad`` and carry no autograd
        history.
    """
    img_l = img_lab[:, :1, :, :] / 50  # target L, roughly [-1, 1]
    l_uncentered = img_l * 50 + 50     # L back on its [0, 100] scale
    with torch.no_grad():
        ref_color_vector = colorEncoder(img_ref / 255.)
        fake_swap_ab = colorUNet((img_l, ref_color_vector))  # ab in [-1, 1]
        fake_rgb = tensor_lab2rgb(torch.cat((l_uncentered, fake_swap_ab * 110), 1))
        # Grayscale conditioning image: same L, zero chroma.
        zero_ab = torch.zeros_like(fake_swap_ab)
        gray_rgb = tensor_lab2rgb(torch.cat((l_uncentered, zero_ab), 1))
    return img / 255., fake_rgb, gray_rgb, img_ref / 255.


def _save_disc_checkpoint(args, step, discriminator, d_optim):
    """Save discriminator and optimizer state to ``experiments/<name>/<step>_ds.pt``.

    ``step`` is zero-padded to 6 digits in the file name. A DataParallel/DDP
    wrapper is stripped first, so the stored keys load into a bare model.
    """
    out_dir = "experiments/%s" % (args.experiment_name)
    mkdirss(out_dir)
    if isinstance(discriminator, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
        discriminator = discriminator.module
    torch.save(
        {
            "discriminator": discriminator.state_dict(),
            "d_optim": d_optim.state_dict(),
            "args": args,
        },
        os.path.join(out_dir, f"{str(step).zfill(6)}_ds.pt"),
    )


def train(
    args,
    loader,
    colorEncoder,
    colorUNet,
    discriminator,
    d_optim,
    device,
):
    """Train only the discriminator against a frozen colorization generator.

    Each step runs three stages:
    1. The generator (``colorEncoder`` + ``colorUNet``) produces a recolored
       image, with no gradients.
    2. The discriminator is scored on the real vs. generated RGB image, both
       conditioned on the gray input and the reference image.
    3. ``d_optim`` takes an LSGAN (MSE) step.

    Args:
        args: Parsed CLI namespace. Reads ``iter``, ``start_iter`` and
            ``experiment_name``.
        loader: Iterable yielding ``(img, img_ref, img_lab)`` batches. It is
            cycled forever.
        colorEncoder, colorUNet: Frozen generator parts (eval mode, no grad).
        discriminator: Model being trained.
        d_optim: Optimizer over the discriminator's parameters.
        device: Target device for the batches.

    Side effects:
        On rank 0, prints the running loss every 100 steps and saves a
        checkpoint every 1000 steps.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    criterion_GAN = torch.nn.MSELoss().to(device)

    # The generator stays frozen and only the discriminator learns. Nothing in
    # the loop changes these modes, so set them once.
    colorEncoder.eval()
    colorUNet.eval()
    discriminator.train()
    requires_grad(colorEncoder, False)
    requires_grad(colorUNet, False)
    requires_grad(discriminator, True)

    disc_val_all = 0.0
    disc_val_count = 0  # number of values accumulated in disc_val_all

    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        img, img_ref, img_lab = next(loader)
        img = img.to(device)
        img_lab = img_lab.to(device)
        img_ref = img_ref.to(device)

        real_img_rgb, fake_swap_rgb, input_img_rgb, img_ref_rgb = _build_disc_inputs(
            colorEncoder, colorUNet, img, img_ref, img_lab
        )

        # LSGAN targets shaped exactly like the PatchGAN output. This avoids a
        # hard-coded patch size that could silently mis-broadcast, and always
        # lands on the predictions' device.
        pred_real = discriminator(real_img_rgb, input_img_rgb, img_ref_rgb)
        loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))
        pred_fake = discriminator(fake_swap_rgb.detach(), input_img_rgb, img_ref_rgb)
        loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))
        disc_loss = 0.5 * (loss_real + loss_fake)

        d_optim.zero_grad()
        disc_loss.backward()
        d_optim.step()

        disc_val = disc_loss.item()
        disc_val_all += disc_val
        disc_val_count += 1

        if get_rank() == 0:
            pbar.set_description(f"discriminator:{disc_val:.4f};")
            if i % 100 == 0:
                # Average over the values actually seen, which is fewer than
                # 100 on the first report and after resuming.
                print(f"discriminator:{disc_val_all / disc_val_count:.4f};")
                disc_val_all = 0.0
                disc_val_count = 0
            if i % 1000 == 0:
                _save_disc_checkpoint(args, i, discriminator, d_optim)
if __name__ == "__main__":
    device = "cuda"

    torch.backends.cudnn.benchmark = True

    parser = argparse.ArgumentParser()
    parser.add_argument("--datasets", type=str)
    parser.add_argument("--iter", type=int, default=100000)
    parser.add_argument("--batch", type=int, default=16)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--ckpt_disc", type=str, default=None)
    parser.add_argument("--lr", type=float, default=0.0002)
    parser.add_argument("--experiment_name", type=str, default="default")
    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # WORLD_SIZE is set by the distributed launcher; default to one process.
    n_gpu = int(os.environ.get("WORLD_SIZE", 1))
    args.distributed = n_gpu > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    args.start_iter = 0

    colorEncoder = ColorEncoder(color_dim=512).to(device)
    colorUNet = ColorUNet(bilinear=True).to(device)
    discriminator = Discriminator(in_channels=3).to(device)

    d_optim = optim.Adam(
        discriminator.parameters(),
        lr=args.lr,
        betas=(0.5, 0.999),
    )

    if args.ckpt is not None:
        print("load model:", args.ckpt)
        ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
        colorEncoder.load_state_dict(ckpt["colorEncoder"])
        colorUNet.load_state_dict(ckpt["colorUNet"])

    if args.ckpt_disc is not None:
        print("load discriminator model:", args.ckpt_disc)
        ckpt_disc = torch.load(args.ckpt_disc, map_location=lambda storage, loc: storage)

        # train() names its checkpoints "<iter zero-padded to 6>_ds.pt". Parse
        # the numeric prefix before the first "_" so training resumes from that
        # iteration. The previous int() on the whole stem ("001000_ds") always
        # failed. Names without a numeric prefix leave start_iter at 0.
        ckpt_stem = os.path.splitext(os.path.basename(args.ckpt_disc))[0]
        try:
            args.start_iter = int(ckpt_stem.split("_")[0])
        except ValueError:
            pass

        discriminator.load_state_dict(ckpt_disc["discriminator"])
        d_optim.load_state_dict(ckpt_disc["d_optim"])

    if args.distributed:
        colorEncoder = nn.parallel.DistributedDataParallel(
            colorEncoder,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

        colorUNet = nn.parallel.DistributedDataParallel(
            colorUNet,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )
        # NOTE(review): the discriminator — the only model being trained — is
        # not DDP-wrapped, so ranks do not all-reduce its gradients and each
        # trains an independent copy. Confirm whether that is intended.

    transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(degrees=(0, 360)),
        ]
    )

    datasets = [MultiResolutionDataset(args.datasets, transform, args.size)]
    train_dataset = data.ConcatDataset(datasets)
    # Sample from the same dataset object the loader indexes into.
    loader = data.DataLoader(
        train_dataset,
        batch_size=args.batch,
        sampler=data_sampler(train_dataset, shuffle=True, distributed=args.distributed),
        drop_last=True,
    )

    train(
        args,
        loader,
        colorEncoder,
        colorUNet,
        discriminator,
        d_optim,
        device,
    )