# newTryOn/scripts/blending_train.py
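"""Training script for the CLIP-based blending encoder.

Trains ``ClipBlendingModel`` to predict blended W+ latents that combine the
face of one FFHQ image with the hair of another, decoding through a frozen
StyleGAN2 generator and supervising with CLIP-embedding losses.
"""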
import argparse
import os
import sys
from argparse import Namespace
from pathlib import Path
from tempfile import TemporaryDirectory
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from PIL import Image
from joblib import Parallel, delayed
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from tqdm.auto import tqdm
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from models.Encoders import ClipBlendingModel as BlendingModel
from models.Net import Net
from models.face_parsing.model import BiSeNet, seg_mean, seg_std
from utils.bicubic import BicubicDownSample
from utils.image_utils import DilateErosion
from utils.train import toggle_grad, WandbLogger, image_grid, seed_everything, get_fid_calc
class Trainer:
def __init__(self,
model=None,
optimizer=None,
scheduler=None,
train_dataloader=None,
test_dataloader=None,
logger=None,
):
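        """Wraps the frozen StyleGAN2 generator, BiSeNet face parser, and
        training state; may be constructed without arguments when only mask
        generation is needed (see ``Blending_dataset``)."""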
self.model = model
self.optimizer = optimizer
self.scheduler = scheduler
self.train_dataloader = train_dataloader
self.test_dataloader = test_dataloader
self.logger = logger
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.dilate_erosion = DilateErosion(device=self.device)
        if self.model is not None:
            # NOTE: relies on the module-level ``args`` parsed under __main__
            self.fid_calc = get_fid_calc('input/fid.pkl', args.fid_dataset)
self.net = Net(Namespace(size=1024, ckpt='pretrained_models/StyleGAN/ffhq.pt', channel_multiplier=2, latent=512,
n_mlp=8, device=self.device))
self.seg = BiSeNet(n_classes=16)
self.seg.to(self.device)
self.seg.eval()
self.seg.load_state_dict(torch.load('pretrained_models/BiSeNet/seg.pth'))
toggle_grad(self.seg, False)
toggle_grad(self.net.generator, False)
self.downsample_512 = BicubicDownSample(factor=2)
self.downsample_256 = BicubicDownSample(factor=4)
self.downsample_128 = BicubicDownSample(factor=8)
self.best_loss = float('+inf')
self.cur_iter = 0
@torch.no_grad()
def generate_mask(self, I):
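        """Segment ``I`` (1024x1024, values in [-1, 1]) with BiSeNet and return
        dilated and eroded 256x256 hair masks (label 10 is the hair class)."""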
IM = (self.downsample_512((I + 1) / 2) - seg_mean) / seg_std
down_seg, _, _ = self.seg(IM)
        current_mask = torch.argmax(down_seg, dim=1).float()
HM_X = torch.where(current_mask == 10, torch.ones_like(current_mask), torch.zeros_like(current_mask))
HM_X = F.interpolate(HM_X.unsqueeze(1), size=(256, 256), mode='nearest')
HM_XD, HM_XE = self.dilate_erosion.mask(HM_X)
return HM_XD, HM_XE
def save_model(self, name, save_online=True):
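        """Checkpoint the blending model, dropping the frozen CLIP weights so
        they are not duplicated on disk."""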
with TemporaryDirectory() as tmp_dir:
model_state_dict = self.model.state_dict()
# delete pretrained clip
for key in list(model_state_dict.keys()):
if key.startswith("clip_model."):
del model_state_dict[key]
torch.save({'model_state_dict': model_state_dict}, f'{tmp_dir}/{name}.pth')
self.logger.save(f'{tmp_dir}/{name}.pth', save_online)
def calc_loss(self, I_gen, I_face, I_color, mask_face, mask_hair, gen_hair):
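        """CLIP cosine-distance losses: the face region of the generated image
        is matched against ``I_face`` and the hair region against ``I_color``;
        ``gen_hair`` is accepted for interface parity but currently unused."""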
gen_embed = self.model.get_image_embed(I_gen * mask_face)
gt_embed = self.model.get_image_embed(I_face * mask_face)
face_loss = (1 - F.cosine_similarity(gen_embed, gt_embed)).mean()
gen_embed = self.model.get_image_embed(I_gen * mask_hair)
gt_embed = self.model.get_image_embed(I_color * mask_hair)
hair_loss = (1 - F.cosine_similarity(gen_embed, gt_embed)).mean()
losses = {'face loss': face_loss, 'hair loss': hair_loss, 'loss': face_loss + hair_loss}
return losses['loss'], losses
def train_one_epoch(self):
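        """One pass over the training set: predict the 12 finer W+ rows (the
        first 6 are zeroed, since decoding starts at layer 4 from ``align_f``),
        generate with the frozen StyleGAN2, and step on the combined loss."""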
self.model.to(self.device).train()
for batch in tqdm(self.train_dataloader):
color_s, align_s, align_f, color_i, face_i, target_mask, HM_3E, HM_XE = map(lambda x: x.to(self.device),
batch)
bsz = color_s.size(0)
blend_s = self.model(align_s[:, 6:], color_s[:, 6:], face_i * target_mask, color_i * HM_3E)
            latent_in = torch.cat((torch.zeros(bsz, 6, 512, device=self.device), blend_s), dim=1)
I_G, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False, start_layer=4,
end_layer=8, layer_in=align_f)
loss, info = self.calc_loss(self.downsample_256(I_G), face_i, color_i, target_mask, HM_3E, HM_XE)
self.optimizer.zero_grad()
loss.backward()
total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
self.optimizer.step()
self.logger.next_step()
for key, val in info.items():
self.logger.log(key, val.item())
self.logger.log('grad', total_norm.item())
self.cur_iter += 1
@torch.no_grad()
def validate(self):
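        """Accumulate losses over the test set, compute FID on the generated
        images, and log a fixed random sample of image triplets to W&B."""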
self.model.to(self.device).eval()
        def sum_losses(acc, new):
            # accumulate per-key loss tensors across batches
            return {key: val + acc.get(key, 0) for key, val in new.items()}
files = []
losses = {}
to_299 = T.Resize((299, 299))
images_to_fid = []
for batch in tqdm(self.test_dataloader):
color_s, align_s, align_f, color_i, face_i, target_mask, HM_3E, HM_XE = map(lambda x: x.to(self.device),
batch)
bsz = color_s.size(0)
blend_s = self.model(align_s[:, 6:], color_s[:, 6:], face_i * target_mask, color_i * HM_3E)
            latent_in = torch.cat((torch.zeros(bsz, 6, 512, device=self.device), blend_s), dim=1)
I_G, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False, start_layer=4,
end_layer=8, layer_in=align_f)
_, info = self.calc_loss(self.downsample_256(I_G), face_i, color_i, target_mask, HM_3E, HM_XE)
losses = sum_losses(losses, info)
for k in range(bsz):
files.append([color_i[k].cpu(), face_i[k].cpu(), self.downsample_256(I_G)[k].cpu()])
images_to_fid.append(to_299((I_G + 1) / 2).clip(0, 1))
losses['FID CLIP'] = self.fid_calc(torch.cat(images_to_fid))
for key, val in losses.items():
if key != 'FID CLIP':
val = val.item() / len(self.test_dataloader)
self.logger.log(f'val {key}', val)
        np.random.seed(1927)  # fixed seed so the same validation examples are logged every epoch
idxs = np.random.choice(len(files), size=100, replace=False)
images_to_log = [
image_grid([T.functional.to_pil_image(((img + 1) / 2).clamp(0, 1)) for img in files[idx]], 1, 3) for idx in
idxs]
self.logger.log('val images', [wandb.Image(image) for image in images_to_log])
return losses['loss']
def train_loop(self, epochs):
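        """Validate once up front, then alternate training and validation,
        keeping both the latest checkpoint and the best one by validation loss."""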
self.validate()
for epoch in range(epochs):
self.train_one_epoch()
loss = self.validate()
self.save_model('last', save_online=False)
if loss <= self.best_loss:
self.best_loss = loss
                self.save_model('best', save_online=False)
def prepare_item(exp, path):
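    """Load one (face, color) training pair: the inverted FS latents, the
    aligned ``latent_F`` tensor, and normalized FFHQ images. Returns ``None``
    on failure so broken items can be filtered out by the caller."""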
    im1, _, im3 = exp  # the middle image of each triple is unused here
try:
color_path = os.path.join(path, 'FS', f'{im3}.npz')
Color_S = torch.from_numpy(np.load(color_path)['latent_in']).squeeze(0)
face_path = os.path.join(path, 'FS', f'{im1}.npz')
Align_S = torch.from_numpy(np.load(face_path)['latent_in']).squeeze(0)
Color_I = T.functional.normalize(T.functional.to_tensor(
Image.open(os.path.join(args.FFHQ, f'{im3}.png'))
), [0.5], [0.5])
Face_I = T.functional.normalize(T.functional.to_tensor(
Image.open(os.path.join(args.FFHQ, f'{im1}.png'))
), [0.5], [0.5])
align_path = os.path.join(path, 'Align')
data = np.load(
os.path.join(align_path, f'{im1}_{im3}.npz')
)
Align_F = torch.from_numpy(data['latent_F']).squeeze(0)
return (Color_S, Align_S, Align_F, Color_I, Face_I)
except Exception as e:
print(e, file=sys.stderr)
return None
class Blending_dataset(Dataset):
def __init__(self, exps, path, net_trainer):
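        """Load items in parallel (each (p1, p2, p3) triple contributes two
        pairs by swapping the color source), generate hair/face masks on the
        GPU, and drop samples whose hair masks are empty."""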
super().__init__()
downsample_256 = BicubicDownSample(factor=4)
data = Parallel(n_jobs=-1)(
delayed(prepare_item)(exp, path) for (p1, p2, p3) in tqdm(exps) for exp in [(p1, p2, p3), (p1, p3, p2)])
data = [elem for elem in data if elem is not None]
print(f'Load: {len(data)}/{2 * len(exps)}', file=sys.stderr)
tmp_dataloader = DataLoader(data, batch_size=24, pin_memory=False, shuffle=False)
self.items = []
with torch.no_grad():
for (Color_S, Align_S, Align_F, Color_I, Face_I) in tqdm(tmp_dataloader):
HM_3D, HM_3E = net_trainer.generate_mask(Color_I.to('cuda'))
HM_1D, _ = net_trainer.generate_mask(Face_I.to('cuda'))
I_X, _ = net_trainer.net.generator([Align_S.to('cuda')], input_is_latent=True, return_latents=False,
start_layer=4,
end_layer=8, layer_in=Align_F.to('cuda'))
HM_XD, HM_XE = net_trainer.generate_mask(I_X)
target_mask = ((1 - HM_1D) * (1 - HM_3D) * (1 - HM_XD)).cpu()
                HM_3E = HM_3E.cpu()
                HM_XE = HM_XE.cpu()
                # Split each batched tensor into per-sample tensors; keep only
                # samples whose hair masks (HM_3E, HM_XE) are non-empty.
                tensors = (Color_S, Align_S, Align_F,
                           downsample_256(Color_I.to('cuda')).cpu(),
                           downsample_256(Face_I.to('cuda')).cpu(),
                           target_mask, HM_3E, HM_XE)
                per_sample = zip(*(torch.unbind(t, dim=0) for t in tensors))
                self.items.extend(item for item in per_sample
                                  if item[-2].any() and item[-1].any())
print(f'dataset: {len(self.items)}/{len(data)}', file=sys.stderr)
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
return self.items[idx]
def main(args):
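    """Read the experiment triples, build the datasets, and run training."""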
seed_everything()
exps = []
with open(os.path.join(args.dataset, 'dataset.exps'), 'r') as file:
for exp in file.readlines():
exps.append(list(map(lambda x: x.replace('.png', ''), exp.split())))
X_train, X_test = train_test_split(exps, test_size=512, random_state=42)
    net_trainer = Trainer()  # model-free instance: supplies the generator and parser for dataset building
train_dataset = Blending_dataset(X_train, args.dataset, net_trainer)
test_dataset = Blending_dataset(X_test, args.dataset, net_trainer)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)
logger = WandbLogger(name=args.name_run, project='Barbershop-Blending')
logger.start_logging()
logger.save(__file__)
model = BlendingModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)
trainer = Trainer(model, optimizer, None, train_dataloader, test_dataloader, logger)
trainer.train_loop(1000)
logger.wandb.finish()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Blending trainer')
parser.add_argument('--name_run', type=str, default='test')
parser.add_argument('--dataset', type=Path, default='input/blending_dataset')
parser.add_argument('--FFHQ', type=Path)
parser.add_argument('--fid_dataset', type=str, default='input')
args = parser.parse_args()
main(args)