Keiser41's picture
Upload 98 files
22d8ab7
raw
history blame
11.4 kB
import argparse
import os
import numpy as np
from PIL import Image
from skimage import color, io
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils import data
from torchvision import transforms
from tqdm import tqdm
# from ColorEncoder import ColorEncoder
from models import ColorEncoder, ColorUNet
from vgg_model import vgg19
from data.data_loader import MultiResolutionDataset
from utils import tensor_lab2rgb
from distributed import (
get_rank,
synchronize,
reduce_loss_dict,
)
def mkdirss(dirpath):
if not os.path.exists(dirpath):
os.makedirs(dirpath)
def data_sampler(dataset, shuffle, distributed):
if distributed:
return data.distributed.DistributedSampler(dataset, shuffle=shuffle)
if shuffle:
return data.RandomSampler(dataset)
else:
return data.SequentialSampler(dataset)
def requires_grad(model, flag=True):
for p in model.parameters():
p.requires_grad = flag
def sample_data(loader):
while True:
for batch in loader:
yield batch
def Lab2RGB_out(img_lab):
img_lab = img_lab.detach().cpu()
img_l = img_lab[:,:1,:,:]
img_ab = img_lab[:,1:,:,:]
# print(torch.max(img_l), torch.min(img_l))
# print(torch.max(img_ab), torch.min(img_ab))
img_l = img_l + 50
pred_lab = torch.cat((img_l, img_ab), 1)[0,...].numpy()
# grid_lab = utils.make_grid(pred_lab, nrow=1).numpy().astype("float64")
# print(grid_lab.shape)
out = (np.clip(color.lab2rgb(pred_lab.transpose(1, 2, 0)), 0, 1)* 255).astype("uint8")
return out
def RGB2Lab(inputs):
# input [0, 255] uint8
# out l: [0, 100], ab: [-110, 110], float32
return color.rgb2lab(inputs)
def Normalize(inputs):
l = inputs[:, :, 0:1]
ab = inputs[:, :, 1:3]
l = l - 50
lab = np.concatenate((l, ab), 2)
return lab.astype('float32')
def numpy2tensor(inputs):
out = torch.from_numpy(inputs.transpose(2,0,1))
return out
def tensor2numpy(inputs):
out = inputs[0,...].detach().cpu().numpy().transpose(1,2,0)
return out
def preprocessing(inputs):
# input: rgb, [0, 255], uint8
img_lab = Normalize(RGB2Lab(inputs))
img = np.array(inputs, 'float32') # [0, 255]
img = numpy2tensor(img)
img_lab = numpy2tensor(img_lab)
return img.unsqueeze(0), img_lab.unsqueeze(0)
def uncenter_l(inputs):
l = inputs[:,:1,:,:] + 50
ab = inputs[:,1:,:,:]
return torch.cat((l, ab), 1)
def train(
args,
loader,
colorEncoder,
colorUNet,
vggnet,
g_optim,
device,
):
loader = sample_data(loader)
pbar = range(args.iter)
if get_rank() == 0:
pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)
g_loss_val = 0
loss_dict = {}
recon_val_all = 0
fea_val_all = 0
if args.distributed:
colorEncoder_module = colorEncoder.module
colorUNet_module = colorUNet.module
else:
colorEncoder_module = colorEncoder
colorUNet_module = colorUNet
for idx in pbar:
i = idx + args.start_iter+1
if i > args.iter:
print("Done!")
break
img, img_ref, img_lab = next(loader)
# ima = img_ref.numpy()
# ima = ima[0].astype('uint8')
# ima = Image.fromarray(ima.transpose(1,2,0))
# ima.show()
img = img.to(device) # GT [B, 3, 256, 256]
img_lab = img_lab.to(device) # GT
img_ref = img_ref.to(device) # tps_transformed image RGB [B, 3, 256, 256]
img_l = img_lab[:,:1,:,:] / 50 # [-1, 1] target L
img_ab = img_lab[:,1:,:,:] / 110 # [-1, 1] target ab
# img_ref_ab = img_ref_lab[:,1:,:,:] / 110 # [-1, 1] ref ab
colorEncoder.train()
colorUNet.train()
requires_grad(colorEncoder, True)
requires_grad(colorUNet, True)
ref_color_vector = colorEncoder(img_ref / 255.)
fake_swap_ab = colorUNet((img_l, ref_color_vector)) # [-1, 1]
## recon l1 loss
recon_loss = (F.smooth_l1_loss(fake_swap_ab, img_ab)) * 1
## feature loss
real_img_rgb = img / 255.
features_A = vggnet(real_img_rgb, layer_name='all')
fake_swap_rgb = tensor_lab2rgb(torch.cat((img_l*50+50, fake_swap_ab*110), 1)) # [0, 1]
features_B = vggnet(fake_swap_rgb, layer_name='all')
# fea_loss = F.l1_loss(features_A[-1], features_B[-1]) * 0.1
# fea_loss = 0
fea_loss1 = F.l1_loss(features_A[0], features_B[0]) / 32 * 0.1
fea_loss2 = F.l1_loss(features_A[1], features_B[1]) / 16 * 0.1
fea_loss3 = F.l1_loss(features_A[2], features_B[2]) / 8 * 0.1
fea_loss4 = F.l1_loss(features_A[3], features_B[3]) / 4 * 0.1
fea_loss5 = F.l1_loss(features_A[4], features_B[4]) * 0.1
fea_loss = fea_loss1 + fea_loss2 + fea_loss3 + fea_loss4 + fea_loss5
loss_dict["recon"] = recon_loss
loss_dict["fea"] = fea_loss
g_optim.zero_grad()
(recon_loss+fea_loss).backward()
g_optim.step()
loss_reduced = reduce_loss_dict(loss_dict)
recon_val = loss_reduced["recon"].mean().item()
recon_val_all += recon_val
# recon_val = 0
fea_val = loss_reduced["fea"].mean().item()
fea_val_all += fea_val
# fea_val = 0
if get_rank() == 0:
pbar.set_description(
(
f"recon:{recon_val:.4f}; fea:{fea_val:.4f};"
)
)
if i % 50 == 0:
print(f"recon_all:{recon_val_all/50:.4f}; fea_all:{fea_val_all/50:.4f};")
recon_val_all = 0
fea_val_all = 0
if i % 500 == 0:
with torch.no_grad():
colorEncoder.eval()
colorUNet.eval()
imgsize = 256
for inum in range(15):
val_img_path = 'test_datasets/val_Manga/in%d.jpg' % (inum + 1)
val_ref_path = 'test_datasets/val_Manga/ref%d.jpg' % (inum + 1)
# val_img_path = 'test_datasets/val_daytime/day_sample/in%d.jpg'%(inum+1)
# val_ref_path = 'test_datasets/val_daytime/night_sample/dark4.jpg'
out_name = 'in%d_ref%d.png'%(inum+1, inum+1)
val_img = Image.open(val_img_path).convert("RGB").resize((imgsize, imgsize))
val_img_ref = Image.open(val_ref_path).convert("RGB").resize((imgsize, imgsize))
val_img, val_img_lab = preprocessing(val_img)
val_img_ref, val_img_ref_lab = preprocessing(val_img_ref)
# val_img = val_img.to(device)
val_img_lab = val_img_lab.to(device)
val_img_ref = val_img_ref.to(device)
# val_img_ref_lab = val_img_ref_lab.to(device)
val_img_l = val_img_lab[:,:1,:,:] / 50. # [-1, 1]
# val_img_ref_ab = val_img_ref_lab[:,1:,:,:] / 110. # [-1, 1]
ref_color_vector = colorEncoder(val_img_ref / 255.) # [0, 1]
fake_swap_ab = colorUNet((val_img_l, ref_color_vector))
fake_img = torch.cat((val_img_l*50, fake_swap_ab*110), 1)
sample = np.concatenate((tensor2numpy(val_img), tensor2numpy(val_img_ref), Lab2RGB_out(fake_img)), 1)
out_dir = 'training_logs/%s/%06d'%(args.experiment_name, i)
mkdirss(out_dir)
io.imsave('%s/%s'%(out_dir, out_name), sample.astype('uint8'))
torch.cuda.empty_cache()
if i % 2500 == 0:
out_dir = "experiments/%s"%(args.experiment_name)
mkdirss(out_dir)
torch.save(
{
"colorEncoder": colorEncoder_module.state_dict(),
"colorUNet": colorUNet_module.state_dict(),
"g_optim": g_optim.state_dict(),
"args": args,
},
f"%s/{str(i).zfill(6)}.pt"%(out_dir),
)
if __name__ == "__main__":
device = "cuda"
torch.backends.cudnn.benchmark = True
parser = argparse.ArgumentParser()
parser.add_argument("--datasets", type=str)
parser.add_argument("--iter", type=int, default=100000)
parser.add_argument("--batch", type=int, default=16)
parser.add_argument("--size", type=int, default=256)
parser.add_argument("--ckpt", type=str, default=None)
parser.add_argument("--lr", type=float, default=0.0001)
parser.add_argument("--experiment_name", type=str, default="default")
parser.add_argument("--wandb", action="store_true")
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()
n_gpu = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
args.distributed = n_gpu > 1
if args.distributed:
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
synchronize()
args.start_iter = 0
vggnet = vgg19(pretrained_path = './experiments/VGG19/vgg19-dcbb9e9d.pth', require_grad = False)
vggnet = vggnet.to(device)
vggnet.eval()
colorEncoder = ColorEncoder(color_dim=512).to(device)
colorUNet = ColorUNet(bilinear=True).to(device)
g_optim = optim.Adam(
list(colorEncoder.parameters()) + list(colorUNet.parameters()),
lr=args.lr,
betas=(0.9, 0.99),
)
if args.ckpt is not None:
print("load model:", args.ckpt)
ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage)
try:
ckpt_name = os.path.basename(args.ckpt)
args.start_iter = int(os.path.splitext(ckpt_name)[0])
except ValueError:
pass
colorEncoder.load_state_dict(ckpt["colorEncoder"])
colorUNet.load_state_dict(ckpt["colorUNet"])
g_optim.load_state_dict(ckpt["g_optim"])
# print(args.distributed)
if args.distributed:
colorEncoder = nn.parallel.DistributedDataParallel(
colorEncoder,
device_ids=[args.local_rank],
output_device=args.local_rank,
broadcast_buffers=False,
)
colorUNet = nn.parallel.DistributedDataParallel(
colorUNet,
device_ids=[args.local_rank],
output_device=args.local_rank,
broadcast_buffers=False,
)
transform = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(degrees=(0, 360))
]
)
datasets = []
dataset = MultiResolutionDataset(args.datasets, transform, args.size)
datasets.append(dataset)
loader = data.DataLoader(
data.ConcatDataset(datasets),
batch_size=args.batch,
sampler=data_sampler(dataset, shuffle=True, distributed=args.distributed),
drop_last=True,
)
train(
args,
loader,
colorEncoder,
colorUNet,
vggnet,
g_optim,
device,
)