import argparse
import os
import sys
from argparse import Namespace
from pathlib import Path
from tempfile import TemporaryDirectory

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from PIL import Image
from joblib import Parallel, delayed
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from tqdm.auto import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from models.Encoders import ClipBlendingModel as BlendingModel
from models.Net import Net
from models.face_parsing.model import BiSeNet, seg_mean, seg_std
from utils.bicubic import BicubicDownSample
from utils.image_utils import DilateErosion
from utils.train import toggle_grad, WandbLogger, image_grid, seed_everything, get_fid_calc
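

# Trainer bundles the frozen StyleGAN generator, the BiSeNet face parser and the
# trainable CLIP-based blending encoder, plus the bicubic downsamplers used to
# match each network's expected input resolution.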
class Trainer:
    def __init__(self,
                 model=None,
                 optimizer=None,
                 scheduler=None,
                 train_dataloader=None,
                 test_dataloader=None,
                 logger=None,
                 ):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.logger = logger
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.dilate_erosion = DilateErosion(device=self.device)

        if self.model is not None:
            # Relies on the module-level `args` parsed in the __main__ block.
            self.fid_calc = get_fid_calc('input/fid.pkl', args.fid_dataset)

        self.net = Net(Namespace(size=1024, ckpt='pretrained_models/StyleGAN/ffhq.pt', channel_multiplier=2,
                                 latent=512, n_mlp=8, device=self.device))

        self.seg = BiSeNet(n_classes=16)
        self.seg.to(self.device)
        self.seg.load_state_dict(torch.load('pretrained_models/BiSeNet/seg.pth', map_location=self.device))
        self.seg.eval()

        toggle_grad(self.seg, False)
        toggle_grad(self.net.generator, False)

        self.downsample_512 = BicubicDownSample(factor=2)
        self.downsample_256 = BicubicDownSample(factor=4)
        self.downsample_128 = BicubicDownSample(factor=8)

        self.best_loss = float('inf')
        self.cur_iter = 0
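
    # Segment a batch of images with BiSeNet and return dilated and eroded
    # versions of the hair mask. Label 10 is assumed to be the hair class in
    # this 16-class parsing model.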
    def generate_mask(self, I):
        IM = (self.downsample_512((I + 1) / 2) - seg_mean) / seg_std
        down_seg, _, _ = self.seg(IM)
        current_mask = torch.argmax(down_seg, dim=1).float()
        HM_X = torch.where(current_mask == 10, torch.ones_like(current_mask), torch.zeros_like(current_mask))
        HM_X = F.interpolate(HM_X.unsqueeze(1), size=(256, 256), mode='nearest')
        HM_XD, HM_XE = self.dilate_erosion.mask(HM_X)
        return HM_XD, HM_XE
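
    # Checkpoint only the trainable weights: the frozen CLIP backbone is
    # stripped from the state dict so the saved file stays small.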
    def save_model(self, name, save_online=True):
        with TemporaryDirectory() as tmp_dir:
            model_state_dict = self.model.state_dict()
            # Drop the pretrained CLIP weights before saving.
            for key in list(model_state_dict.keys()):
                if key.startswith("clip_model."):
                    del model_state_dict[key]
            torch.save({'model_state_dict': model_state_dict}, f'{tmp_dir}/{name}.pth')
            self.logger.save(f'{tmp_dir}/{name}.pth', save_online)
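
    # Both loss terms are cosine distances between CLIP image embeddings: the
    # face region of the blend should match the face image, and the hair region
    # should match the colour reference. Note that `gen_hair` is accepted but
    # not used by the current loss.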
    def calc_loss(self, I_gen, I_face, I_color, mask_face, mask_hair, gen_hair):
        gen_embed = self.model.get_image_embed(I_gen * mask_face)
        gt_embed = self.model.get_image_embed(I_face * mask_face)
        face_loss = (1 - F.cosine_similarity(gen_embed, gt_embed)).mean()

        gen_embed = self.model.get_image_embed(I_gen * mask_hair)
        gt_embed = self.model.get_image_embed(I_color * mask_hair)
        hair_loss = (1 - F.cosine_similarity(gen_embed, gt_embed)).mean()

        losses = {'face loss': face_loss, 'hair loss': hair_loss, 'loss': face_loss + hair_loss}
        return losses['loss'], losses
    def train_one_epoch(self):
        self.model.to(self.device).train()
        for batch in tqdm(self.train_dataloader):
            color_s, align_s, align_f, color_i, face_i, target_mask, HM_3E, HM_XE = map(
                lambda x: x.to(self.device), batch)
            bsz = color_s.size(0)

            blend_s = self.model(align_s[:, 6:], color_s[:, 6:], face_i * target_mask, color_i * HM_3E)
            latent_in = torch.cat((torch.zeros(bsz, 6, 512, device=self.device), blend_s), dim=1)
            I_G, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False,
                                        start_layer=4, end_layer=8, layer_in=align_f)

            loss, info = self.calc_loss(self.downsample_256(I_G), face_i, color_i, target_mask, HM_3E, HM_XE)

            self.optimizer.zero_grad()
            loss.backward()
            total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5)
            self.optimizer.step()

            self.logger.next_step()
            for key, val in info.items():
                self.logger.log(key, val.item())
            self.logger.log('grad', total_norm.item())
            self.cur_iter += 1
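
    # Full pass over the held-out pairs: accumulates the blending losses,
    # computes FID over the generated images and logs a fixed random sample of
    # (colour reference | face | blend) image triptychs.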
    @torch.no_grad()
    def validate(self):
        self.model.to(self.device).eval()
        sum_losses = lambda x, y: {key: val + x.get(key, 0) for key, val in y.items()}
        files = []
        losses = {}
        to_299 = T.Resize((299, 299))
        images_to_fid = []

        for batch in tqdm(self.test_dataloader):
            color_s, align_s, align_f, color_i, face_i, target_mask, HM_3E, HM_XE = map(
                lambda x: x.to(self.device), batch)
            bsz = color_s.size(0)

            blend_s = self.model(align_s[:, 6:], color_s[:, 6:], face_i * target_mask, color_i * HM_3E)
            latent_in = torch.cat((torch.zeros(bsz, 6, 512, device=self.device), blend_s), dim=1)
            I_G, _ = self.net.generator([latent_in], input_is_latent=True, return_latents=False,
                                        start_layer=4, end_layer=8, layer_in=align_f)

            _, info = self.calc_loss(self.downsample_256(I_G), face_i, color_i, target_mask, HM_3E, HM_XE)
            losses = sum_losses(losses, info)

            for k in range(bsz):
                files.append([color_i[k].cpu(), face_i[k].cpu(), self.downsample_256(I_G)[k].cpu()])
            images_to_fid.append(to_299((I_G + 1) / 2).clip(0, 1))

        losses['FID CLIP'] = self.fid_calc(torch.cat(images_to_fid))
        for key, val in losses.items():
            if key != 'FID CLIP':
                val = val.item() / len(self.test_dataloader)
            self.logger.log(f'val {key}', val)

        # Fixed seed so the same validation samples are logged every epoch.
        np.random.seed(1927)
        idxs = np.random.choice(len(files), size=min(100, len(files)), replace=False)
        images_to_log = [
            image_grid([T.functional.to_pil_image(((img + 1) / 2).clamp(0, 1)) for img in files[idx]], 1, 3)
            for idx in idxs]
        self.logger.log('val images', [wandb.Image(image) for image in images_to_log])
        return losses['loss']
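
    # Validate once before training for a baseline, then alternate train/eval
    # epochs, keeping both the latest and the best (lowest validation loss)
    # checkpoints.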
    def train_loop(self, epochs):
        self.validate()
        for epoch in range(epochs):
            self.train_one_epoch()
            loss = self.validate()
            self.save_model('last', save_online=False)
            if loss <= self.best_loss:
                self.best_loss = loss
                self.save_model('best', save_online=False)
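

# Load one (face, colour) training pair from disk: the FS latents, the cached
# alignment feature map and the two normalised FFHQ images. Returns None when
# any file is missing so failed items can be filtered out afterwards.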
def prepare_item(exp, path):
    im1, im2, im3 = exp
    try:
        color_path = os.path.join(path, 'FS', f'{im3}.npz')
        Color_S = torch.from_numpy(np.load(color_path)['latent_in']).squeeze(0)

        face_path = os.path.join(path, 'FS', f'{im1}.npz')
        Align_S = torch.from_numpy(np.load(face_path)['latent_in']).squeeze(0)

        # `args` is the module-level namespace parsed in the __main__ block.
        Color_I = T.functional.normalize(T.functional.to_tensor(
            Image.open(os.path.join(args.FFHQ, f'{im3}.png'))
        ), [0.5], [0.5])
        Face_I = T.functional.normalize(T.functional.to_tensor(
            Image.open(os.path.join(args.FFHQ, f'{im1}.png'))
        ), [0.5], [0.5])

        align_path = os.path.join(path, 'Align')
        data = np.load(os.path.join(align_path, f'{im1}_{im3}.npz'))
        Align_F = torch.from_numpy(data['latent_F']).squeeze(0)

        return (Color_S, Align_S, Align_F, Color_I, Face_I)
    except Exception as e:
        print(e, file=sys.stderr)
        return None
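

# Materialises the whole dataset up front: loads every pair in parallel, then
# runs the frozen generator and parser once to precompute the hair/face masks,
# keeping only samples whose eroded hair masks are non-empty.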
class Blending_dataset(Dataset):
    def __init__(self, exps, path, net_trainer):
        super().__init__()
        downsample_256 = BicubicDownSample(factor=4)
        device = net_trainer.device

        # Each (p1, p2, p3) triple yields two samples, with the face and colour
        # roles swapped.
        data = Parallel(n_jobs=-1)(
            delayed(prepare_item)(exp, path) for (p1, p2, p3) in tqdm(exps) for exp in [(p1, p2, p3), (p1, p3, p2)])
        data = [elem for elem in data if elem is not None]
        print(f'Load: {len(data)}/{2 * len(exps)}', file=sys.stderr)

        tmp_dataloader = DataLoader(data, batch_size=24, pin_memory=False, shuffle=False)
        self.items = []
        with torch.no_grad():
            for (Color_S, Align_S, Align_F, Color_I, Face_I) in tqdm(tmp_dataloader):
                HM_3D, HM_3E = net_trainer.generate_mask(Color_I.to(device))
                HM_1D, _ = net_trainer.generate_mask(Face_I.to(device))
                I_X, _ = net_trainer.net.generator([Align_S.to(device)], input_is_latent=True,
                                                   return_latents=False, start_layer=4, end_layer=8,
                                                   layer_in=Align_F.to(device))
                HM_XD, HM_XE = net_trainer.generate_mask(I_X)

                # Region that is hair in neither source image nor the aligned render.
                target_mask = ((1 - HM_1D) * (1 - HM_3D) * (1 - HM_XD)).cpu()
                HM_3E = HM_3E.cpu()
                HM_XE = HM_XE.cpu()

                # Split each batched tensor into per-sample tensors and keep
                # only samples whose eroded hair masks are non-empty.
                per_sample = [t.unbind(0) for t in (Color_S, Align_S, Align_F,
                                                    downsample_256(Color_I.to(device)).cpu(),
                                                    downsample_256(Face_I.to(device)).cpu(),
                                                    target_mask, HM_3E, HM_XE)]
                self.items.extend(
                    item for item in zip(*per_sample)
                    if item[-2].any() and item[-1].any())
        print(f'dataset: {len(self.items)}/{len(data)}', file=sys.stderr)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]
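

# End-to-end entry point: read the experiment list, split off a fixed
# validation set, precompute both datasets with a model-less Trainer, then
# train the blending encoder with Adam and W&B logging.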
def main(args):
    seed_everything()

    exps = []
    with open(os.path.join(args.dataset, 'dataset.exps'), 'r') as file:
        for exp in file.readlines():
            exps.append(list(map(lambda x: x.replace('.png', ''), exp.split())))
    X_train, X_test = train_test_split(exps, test_size=512, random_state=42)

    # A Trainer without a model is used only for its generator and mask utilities.
    net_trainer = Trainer()
    train_dataset = Blending_dataset(X_train, args.dataset, net_trainer)
    test_dataset = Blending_dataset(X_test, args.dataset, net_trainer)
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    logger = WandbLogger(name=args.name_run, project='Barbershop-Blending')
    logger.start_logging()
    logger.save(__file__)

    model = BlendingModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-6)
    trainer = Trainer(model, optimizer, None, train_dataloader, test_dataloader, logger)
    trainer.train_loop(1000)
    logger.wandb.finish()
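

# Example invocation (the script name and paths are placeholders for your
# local setup):
#   python train_blending.py --name_run my_run --dataset input/blending_dataset \
#       --FFHQ /path/to/ffhq1024 --fid_dataset input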
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Blending trainer')
    parser.add_argument('--name_run', type=str, default='test')
    parser.add_argument('--dataset', type=Path, default='input/blending_dataset')
    parser.add_argument('--FFHQ', type=Path)
    parser.add_argument('--fid_dataset', type=str, default='input')
    args = parser.parse_args()

    main(args)