# PTI/scripts/latent_creators/base_latent_creator.py
# Scraped page header: ucalyptus's picture / "pushes" / f799d59
import abc
import logging
import pickle
import os
from random import choice
from string import ascii_uppercase
import torch
from torch.utils.data import DataLoader
from configs import global_config, paths_config
from tqdm import tqdm
from torchvision import transforms
from utils.ImagesDataset import ImagesDataset
class BaseLatentCreator:
    """Base class that projects every input image to a latent and saves it.

    Subclasses supply the actual projection by overriding
    ``run_projection``; ``create_latents`` drives it over the whole dataset
    and writes each result to
    ``<embedding_base_dir>/<input_data_id>/<method_name>/<fname>/0.pt``.

    NOTE(review): ``run_projection`` is decorated with
    ``abc.abstractmethod`` but this class does not use ``abc.ABCMeta``, so
    the abstractness is not enforced at instantiation time. It is left that
    way here to stay backward-compatible with existing callers.
    """

    def __init__(self, method_name, dara_preprocess=None, use_wandb=False):
        """Set up the run name, dataset loader, output folders and generator.

        Args:
            method_name: Name of the projection method; used as the last
                component of the output latent folder.
            dara_preprocess: Optional transform applied to the images by
                ``ImagesDataset``. When ``None``, a default pipeline
                (resize to 1024, to-tensor, normalize with mean/std 0.5) is
                used. The parameter name (sic) is kept as-is for callers
                passing it by keyword.
            use_wandb: When true, start a Weights & Biases run named after
                the generated run name.
        """
        # Random 12-letter uppercase identifier, stored globally for the run.
        global_config.run_name = ''.join(choice(ascii_uppercase) for _ in range(12))

        self.use_wandb = use_wandb
        if use_wandb:
            # Imported lazily: this module has no top-level ``import wandb``,
            # so the call below used to fail with NameError. A local import
            # also keeps wandb optional when ``use_wandb`` is False.
            import wandb

            wandb.init(project="personalized_stylegan", reinit=True, name=global_config.run_name)

        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices

        if dara_preprocess is None:
            self.projection_preprocess = transforms.Compose([
                transforms.Resize(1024),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        else:
            self.projection_preprocess = dara_preprocess

        image_dataset = ImagesDataset(f'{paths_config.input_data_path}', self.projection_preprocess)
        # One image per batch, in dataset order.
        self.image_dataloader = DataLoader(image_dataset, batch_size=1, shuffle=False)

        base_latent_folder_path = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
        os.makedirs(base_latent_folder_path, exist_ok=True)
        self.latent_folder_path = f'{base_latent_folder_path}/{method_name}'
        os.makedirs(self.latent_folder_path, exist_ok=True)

        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only point ``paths_config.stylegan2_ada_ffhq`` at a trusted checkpoint.
        with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
            self.old_G = pickle.load(f)['G_ema'].cuda()

    @abc.abstractmethod
    def run_projection(self, fname, image):
        """Project a single image to a latent; meant to be overridden.

        Args:
            fname: Name of the image, as yielded by the data loader.
            image: The image batch, already moved to the GPU by
                ``create_latents``.

        Returns:
            The object to be saved for this image. This base implementation
            returns ``None``.
        """
        return None

    def create_latents(self):
        """Run the projection for every image and save each result to disk.

        For each ``(fname, image)`` pair from the data loader, the result of
        ``run_projection`` is written with ``torch.save`` to
        ``<latent_folder_path>/<fname>/0.pt``; the per-image folder is
        created if it does not exist.
        """
        for fname, image in tqdm(self.image_dataloader):
            # The loader batches with batch_size=1, so take the single name.
            fname = fname[0]
            cur_latent_folder_path = f'{self.latent_folder_path}/{fname}'

            image = image.cuda()
            w = self.run_projection(fname, image)

            os.makedirs(cur_latent_folder_path, exist_ok=True)
            torch.save(w, f'{cur_latent_folder_path}/0.pt')