import abc
import logging
import os
import pickle
from random import choice
from string import ascii_uppercase

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from configs import global_config, paths_config
from utils.ImagesDataset import ImagesDataset


class BaseLatentCreator:
    """Base class that projects every input image to a latent and saves it.

    Subclasses implement :meth:`run_projection`; :meth:`create_latents`
    drives it over the dataset built in ``__init__``.

    NOTE(review): ``run_projection`` is decorated with
    ``abc.abstractmethod`` but this class does not use ``abc.ABCMeta``, so
    the decorator is not enforced at instantiation time. The hierarchy is
    left as-is to avoid breaking existing callers.
    """

    def __init__(self, method_name, dara_preprocess=None, use_wandb=False):
        """Set up the run name, device environment, data loader and generator.

        Args:
            method_name: Name of the sub-folder (under the per-dataset
                embedding folder) where this method's latents are written.
            dara_preprocess: Optional transform passed to ``ImagesDataset``.
                When ``None``, a default pipeline of ``Resize(1024)``,
                ``ToTensor`` and ``Normalize`` with mean/std 0.5 per channel
                is used. (The parameter name is kept as-is for
                compatibility with keyword callers.)
            use_wandb: When true, start a Weights & Biases run named after
                ``global_config.run_name``.
        """
        # A fresh random 12-letter uppercase run name is stored globally.
        global_config.run_name = ''.join(
            choice(ascii_uppercase) for _ in range(12))

        self.use_wandb = use_wandb
        if use_wandb:
            # Imported lazily: the module never imported wandb at file level
            # (the original raised NameError here), and it is only required
            # when logging is actually requested.
            import wandb

            wandb.init(project="personalized_stylegan", reinit=True,
                       name=global_config.run_name)

        os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
        os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices

        if dara_preprocess is None:
            self.projection_preprocess = transforms.Compose([
                transforms.Resize(1024),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        else:
            self.projection_preprocess = dara_preprocess

        image_dataset = ImagesDataset(f'{paths_config.input_data_path}',
                                      self.projection_preprocess)
        self.image_dataloader = DataLoader(image_dataset, batch_size=1,
                                           shuffle=False)

        base_latent_folder_path = (
            f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}')
        self.latent_folder_path = f'{base_latent_folder_path}/{method_name}'
        # makedirs creates the intermediate base folder as well, so a single
        # call covers both directories.
        os.makedirs(self.latent_folder_path, exist_ok=True)

        # SECURITY: pickle.load executes arbitrary code contained in the
        # file; only point paths_config.stylegan2_ada_ffhq at a trusted
        # checkpoint.
        with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
            self.old_G = pickle.load(f)['G_ema'].cuda()

    @abc.abstractmethod
    def run_projection(self, fname, image):
        """Project one image and return the value to be saved for it.

        Args:
            fname: First element of the name batch yielded by the loader.
            image: Image batch from the loader, already moved to CUDA.

        Returns:
            The object that ``create_latents`` passes to ``torch.save``.
            This base implementation returns ``None``.
        """
        return None

    def create_latents(self):
        """Run the projection on every loader item and save each result.

        For each item the result of :meth:`run_projection` is written to
        ``<latent_folder_path>/<fname>/0.pt``.
        """
        for fname, image in tqdm(self.image_dataloader):
            # The loader yields batches of size 1; unwrap the single name.
            fname = fname[0]
            cur_latent_folder_path = f'{self.latent_folder_path}/{fname}'

            image = image.cuda()
            w = self.run_projection(fname, image)

            os.makedirs(cur_latent_folder_path, exist_ok=True)
            torch.save(w, f'{cur_latent_folder_path}/0.pt')