# pytorchAnimeGAN/trainer/__init__.py
import os
import time
import shutil
import torch
import cv2
import torch.optim as optim
import numpy as np
from glob import glob
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from utils.image_processing import denormalize_input, preprocess_images, resize_image
from losses import LossSummary, AnimeGanLoss, to_gray_scale
from utils import load_checkpoint, save_checkpoint, read_image
from utils.common import set_lr
from color_transfer import color_transfer_pytorch
def transfer_color_and_rescale(src, target):
"""Transfer color from src image to target then rescale to [-1, 1]"""
    out = color_transfer_pytorch(src, target)  # output is in [0, 1]
    out = (out / 0.5) - 1  # rescale [0, 1] -> [-1, 1]
return out
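# A quick sketch (not part of the training flow) of the rescale above,
# which maps [0, 1] -> [-1, 1]:
#
#     x = torch.tensor([0.0, 0.5, 1.0])
#     (x / 0.5) - 1  # -> tensor([-1.,  0.,  1.])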
def gaussian_noise():
gaussian_mean = torch.tensor(0.0)
gaussian_std = torch.tensor(0.1)
return torch.normal(gaussian_mean, gaussian_std)
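# Note: with scalar mean/std tensors, torch.normal returns a single 0-dim
# sample, so the noise added to a batch in `train_epoch` is one shared
# scalar offset per call rather than per-pixel noise.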
def convert_to_readable(seconds):
return time.strftime('%H:%M:%S', time.gmtime(seconds))
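# e.g. convert_to_readable(3725) -> '01:02:05'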
def revert_to_np_image(image_tensor):
image = image_tensor.cpu().numpy()
    # CHW -> HWC
    image = image.transpose(1, 2, 0)
    image = denormalize_input(image, dtype=np.int16)
    return image[..., ::-1]  # RGB -> BGR for cv2.imwrite
def save_generated_images(images: torch.Tensor, save_dir: str):
    """Save generated images `(*, 3, H, W)` in range [-1, 1] to disk"""
os.makedirs(save_dir, exist_ok=True)
images = images.clone().detach().cpu().numpy()
    images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
    n_images = len(images)
    for i in range(n_images):
        img = images[i]
        img = denormalize_input(img, dtype=np.int16)
        img = img[..., ::-1]  # RGB -> BGR for cv2.imwrite
cv2.imwrite(os.path.join(save_dir, f"G{i}.jpg"), img)
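# Example (illustrative; the variables are placeholders):
#
#     save_generated_images(fake_img, os.path.join(cfg.exp_dir, "debug_batch"))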
class DDPTrainer:
def _init_distributed(self):
if self.cfg.ddp:
self.logger.info("Setting up DDP")
            # init_process_group returns None, so the SyncBatchNorm conversion
            # below falls back to the default process group.
            self.pg = torch.distributed.init_process_group(
backend="nccl",
rank=self.cfg.local_rank,
world_size=self.cfg.world_size
)
self.G = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.G, self.pg)
self.D = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.D, self.pg)
torch.cuda.set_device(self.cfg.local_rank)
self.G.cuda(self.cfg.local_rank)
self.D.cuda(self.cfg.local_rank)
self.logger.info("Setting up DDP Done")
def _init_amp(self, enabled=False):
# self.scaler = torch.cuda.amp.GradScaler(enabled=enabled, growth_interval=100)
self.scaler_g = GradScaler(enabled=enabled)
self.scaler_d = GradScaler(enabled=enabled)
if self.cfg.ddp:
self.G = DistributedDataParallel(
self.G, device_ids=[self.cfg.local_rank],
output_device=self.cfg.local_rank,
find_unused_parameters=False)
self.D = DistributedDataParallel(
self.D, device_ids=[self.cfg.local_rank],
output_device=self.cfg.local_rank,
find_unused_parameters=False)
self.logger.info("Set DistributedDataParallel")
class Trainer(DDPTrainer):
"""
Base Trainer class
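
    Example (illustrative sketch; `Generator`, `Discriminator` and the exact
    config fields are assumptions, not guaranteed names from this repo):
    ```
    G, D = Generator(), Discriminator()
    trainer = Trainer(G, D, config=cfg, logger=logger)
    trainer.train(train_dataset)  # pretrains G, then adversarial training
    ```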
"""
def __init__(
self,
generator,
discriminator,
config,
logger,
) -> None:
self.G = generator
self.D = discriminator
self.cfg = config
        self.max_norm = 10  # gradient clipping threshold
self.device_type = 'cuda' if self.cfg.device.startswith('cuda') else 'cpu'
self.optimizer_g = optim.Adam(self.G.parameters(), lr=self.cfg.lr_g, betas=(0.5, 0.999))
self.optimizer_d = optim.Adam(self.D.parameters(), lr=self.cfg.lr_d, betas=(0.5, 0.999))
self.loss_tracker = LossSummary()
if self.cfg.ddp:
self.device = torch.device(f"cuda:{self.cfg.local_rank}")
            logger.info(f"DDP local rank {self.cfg.local_rank}, device {self.device}")
else:
self.device = torch.device(self.cfg.device)
self.loss_fn = AnimeGanLoss(self.cfg, self.device, self.cfg.gray_adv)
self.logger = logger
self._init_working_dir()
self._init_distributed()
self._init_amp(enabled=self.cfg.amp)
    def _init_working_dir(self):
        """Create working directories for checkpoints and generated images."""
os.makedirs(self.cfg.exp_dir, exist_ok=True)
Gname = self.G.name
Dname = self.D.name
self.checkpoint_path_G_init = os.path.join(self.cfg.exp_dir, f"{Gname}_init.pt")
self.checkpoint_path_G = os.path.join(self.cfg.exp_dir, f"{Gname}.pt")
self.checkpoint_path_D = os.path.join(self.cfg.exp_dir, f"{Dname}.pt")
self.save_image_dir = os.path.join(self.cfg.exp_dir, "generated_images")
self.example_image_dir = os.path.join(self.cfg.exp_dir, "train_images")
os.makedirs(self.save_image_dir, exist_ok=True)
os.makedirs(self.example_image_dir, exist_ok=True)
def init_weight_G(self, weight: str):
"""Init Generator weight"""
return load_checkpoint(self.G, weight)
def init_weight_D(self, weight: str):
"""Init Discriminator weight"""
return load_checkpoint(self.D, weight)
def pretrain_generator(self, train_loader, start_epoch):
"""
        Pretrain the Generator to reconstruct the input image (content loss only).
"""
init_losses = []
set_lr(self.optimizer_g, self.cfg.init_lr)
for epoch in range(start_epoch, self.cfg.init_epochs):
# Train with content loss only
pbar = tqdm(train_loader)
for data in pbar:
img = data["image"].to(self.device)
self.optimizer_g.zero_grad()
with autocast(enabled=self.cfg.amp):
fake_img = self.G(img)
loss = self.loss_fn.content_loss_vgg(img, fake_img)
self.scaler_g.scale(loss).backward()
self.scaler_g.step(self.optimizer_g)
self.scaler_g.update()
if self.cfg.ddp:
torch.distributed.barrier()
init_losses.append(loss.cpu().detach().numpy())
                avg_content_loss = sum(init_losses) / len(init_losses)
                pbar.set_description(f'[Init Training G] content loss: {avg_content_loss:.2f}')
save_checkpoint(self.G, self.checkpoint_path_G_init, self.optimizer_g, epoch)
if self.cfg.local_rank == 0:
self.generate_and_save(self.cfg.test_image_dir, subname='initg')
self.logger.info(f"Epoch {epoch}/{self.cfg.init_epochs}")
set_lr(self.optimizer_g, self.cfg.lr_g)
def train_epoch(self, epoch, train_loader):
pbar = tqdm(train_loader, total=len(train_loader))
for data in pbar:
img = data["image"].to(self.device)
anime = data["anime"].to(self.device)
anime_gray = data["anime_gray"].to(self.device)
anime_smt_gray = data["smooth_gray"].to(self.device)
# ---------------- TRAIN D ---------------- #
self.optimizer_d.zero_grad()
with autocast(enabled=self.cfg.amp):
fake_img = self.G(img)
# Add some Gaussian noise to images before feeding to D
if self.cfg.d_noise:
fake_img += gaussian_noise()
anime += gaussian_noise()
anime_gray += gaussian_noise()
anime_smt_gray += gaussian_noise()
if self.cfg.gray_adv:
fake_img = to_gray_scale(fake_img)
fake_d = self.D(fake_img)
real_anime_d = self.D(anime)
real_anime_gray_d = self.D(anime_gray)
real_anime_smt_gray_d = self.D(anime_smt_gray)
loss_d = self.loss_fn.compute_loss_D(
fake_d,
real_anime_d,
real_anime_gray_d,
real_anime_smt_gray_d
)
self.scaler_d.scale(loss_d).backward()
self.scaler_d.unscale_(self.optimizer_d)
torch.nn.utils.clip_grad_norm_(self.D.parameters(), max_norm=self.max_norm)
self.scaler_d.step(self.optimizer_d)
self.scaler_d.update()
if self.cfg.ddp:
torch.distributed.barrier()
self.loss_tracker.update_loss_D(loss_d)
# ---------------- TRAIN G ---------------- #
self.optimizer_g.zero_grad()
with autocast(enabled=self.cfg.amp):
fake_img = self.G(img)
if self.cfg.gray_adv:
fake_d = self.D(to_gray_scale(fake_img))
else:
fake_d = self.D(fake_img)
(
adv_loss, con_loss,
gra_loss, col_loss,
tv_loss
) = self.loss_fn.compute_loss_G(
fake_img,
img,
fake_d,
anime_gray,
)
loss_g = adv_loss + con_loss + gra_loss + col_loss + tv_loss
if torch.isnan(adv_loss).any():
self.logger.info("----------------------------------------------")
self.logger.info(fake_d)
self.logger.info(adv_loss)
self.logger.info("----------------------------------------------")
raise ValueError("NAN loss!!")
self.scaler_g.scale(loss_g).backward()
            self.scaler_g.unscale_(self.optimizer_g)
grad = torch.nn.utils.clip_grad_norm_(self.G.parameters(), max_norm=self.max_norm)
self.scaler_g.step(self.optimizer_g)
self.scaler_g.update()
if self.cfg.ddp:
torch.distributed.barrier()
self.loss_tracker.update_loss_G(adv_loss, gra_loss, col_loss, con_loss)
pbar.set_description(f"{self.loss_tracker.get_loss_description()} - {grad:.3f}")
def get_train_loader(self, dataset):
if self.cfg.ddp:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
else:
train_sampler = None
return DataLoader(
dataset,
batch_size=self.cfg.batch_size,
num_workers=self.cfg.num_workers,
pin_memory=True,
shuffle=train_sampler is None,
sampler=train_sampler,
drop_last=True,
# collate_fn=collate_fn,
)
def maybe_increase_imgsz(self, epoch, train_dataset):
"""
Increase image size at specific epoch
        + The first 50% of epochs train at imgsz[0]
        + The remaining 50% increase the size every `epochs / 2 / (len(imgsz) - 1)` epochs
Args:
epoch: Current epoch
train_dataset: Dataset
Examples:
```
epochs = 100
imgsz = [256, 352, 416, 512]
=> [(0, 256), (50, 352), (66, 416), (82, 512)]
```
"""
epochs = self.cfg.epochs
imgsz = self.cfg.imgsz
num_size_remains = len(imgsz) - 1
half_epochs = epochs // 2
if len(imgsz) == 1:
new_size = imgsz[0]
elif epoch < half_epochs:
new_size = imgsz[0]
else:
per_epoch_increment = int(half_epochs / num_size_remains)
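            # e.g. epochs=100, imgsz=[256, 352, 416, 512]:
            # per_epoch_increment = int(50 / 3) = 16 -> switch at epochs 50, 66, 82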
            found = None
            for i, size in enumerate(imgsz):
                if epoch < half_epochs + per_epoch_increment * i:
                    found = size
                    break
            if found is None:
                found = imgsz[-1]
            new_size = found
        self.logger.info(f"Image size schedule {imgsz}, target {new_size}, current {train_dataset.imgsz}")
if new_size != train_dataset.imgsz:
train_dataset.set_imgsz(new_size)
self.logger.info(f"Increase image size to {new_size} at epoch {epoch}")
def train(self, train_dataset: Dataset, start_epoch=0, start_epoch_g=0):
"""
Train Generator and Discriminator.
"""
self.logger.info(self.device)
self.G.to(self.device)
self.D.to(self.device)
self.pretrain_generator(self.get_train_loader(train_dataset), start_epoch_g)
if self.cfg.local_rank == 0:
self.logger.info(f"Start training for {self.cfg.epochs} epochs")
for i, data in enumerate(train_dataset):
for k in data.keys():
image = data[k]
cv2.imwrite(
os.path.join(self.example_image_dir, f"data_{k}_{i}.jpg"),
revert_to_np_image(image)
)
if i == 2:
break
end = None
num_iter = 0
per_epoch_times = []
for epoch in range(start_epoch, self.cfg.epochs):
self.maybe_increase_imgsz(epoch, train_dataset)
start = time.time()
self.train_epoch(epoch, self.get_train_loader(train_dataset))
if epoch % self.cfg.save_interval == 0 and self.cfg.local_rank == 0:
                save_checkpoint(self.G, self.checkpoint_path_G, self.optimizer_g, epoch)
save_checkpoint(self.D, self.checkpoint_path_D, self.optimizer_d, epoch)
self.generate_and_save(self.cfg.test_image_dir)
if epoch % 10 == 0:
self.copy_results(epoch)
num_iter += 1
if self.cfg.local_rank == 0:
end = time.time()
if end is None:
eta = 9999
else:
per_epoch_time = (end - start)
per_epoch_times.append(per_epoch_time)
eta = np.mean(per_epoch_times) * (self.cfg.epochs - epoch)
eta = convert_to_readable(eta)
self.logger.info(f"epoch {epoch}/{self.cfg.epochs}, ETA: {eta}")
def generate_and_save(
self,
image_dir,
max_imgs=15,
subname='gen'
):
        """Run the Generator on images in `image_dir` and save real/fake pairs side by side."""
start = time.time()
self.G.eval()
max_iter = max_imgs
fake_imgs = []
real_imgs = []
image_files = glob(os.path.join(image_dir, "*"))
for i, image_file in enumerate(image_files):
image = read_image(image_file)
image = resize_image(image)
real_imgs.append(image.copy())
image = preprocess_images(image)
image = image.to(self.device)
with torch.no_grad():
with autocast(enabled=self.cfg.amp):
fake_img = self.G(image)
# fake_img = to_gray_scale(fake_img)
fake_img = fake_img.detach().cpu().numpy()
# Channel first -> channel last
fake_img = fake_img.transpose(0, 2, 3, 1)
fake_imgs.append(denormalize_input(fake_img, dtype=np.int16)[0])
if i + 1 == max_iter:
break
# fake_imgs = np.concatenate(fake_imgs, axis=0)
for i, (real_img, fake_img) in enumerate(zip(real_imgs, fake_imgs)):
            img = np.concatenate((real_img, fake_img), axis=1)  # Concatenate across width
save_path = os.path.join(self.save_image_dir, f'{subname}_{i}.jpg')
if not cv2.imwrite(save_path, img[..., ::-1]):
                self.logger.info(f"Failed to save generated image {save_path}, shape {img.shape}")
elapsed = time.time() - start
self.logger.info(f"Generated {len(fake_imgs)} images in {elapsed:.3f}s.")
    def copy_results(self, epoch):
        """Copy results (weights + generated images) into a per-epoch folder,
        every N epochs.
        """
copy_dir = os.path.join(self.cfg.exp_dir, f"epoch_{epoch}")
os.makedirs(copy_dir, exist_ok=True)
shutil.copy2(
self.checkpoint_path_G,
copy_dir
)
dest = os.path.join(copy_dir, os.path.basename(self.save_image_dir))
shutil.copytree(
self.save_image_dir,
dest,
dirs_exist_ok=True
)