import pickle from PTI.utils.ImagesDataset import ImagesDataset, Image2Dataset import torch from PTI.utils.models_utils import load_old_G from PTI.utils.alignment import align_face from PTI.training.coaches.single_id_coach import SingleIDCoach from PTI.configs import global_config, paths_config import dlib import os from torchvision.transforms import transforms from torch.utils.data import DataLoader from string import ascii_uppercase import sys from pathlib import Path sys.path.append(".") # sys.path.append('PTI/') # sys.path.append('PTI/training/') def run_PTI(img, run_name): # os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' # os.environ['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices global_config.run_name = run_name global_config.pivotal_training_steps = 1 global_config.training_step = 1 embedding_dir_path = f"{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}" os.makedirs(embedding_dir_path, exist_ok=True) # dataset = ImagesDataset(paths_config.input_data_path, transforms.Compose([ # transforms.ToTensor(), # transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])) G = load_old_G() IMAGE_SIZE = 1024 predictor = dlib.shape_predictor(paths_config.dlib) aligned_image = align_face(img, predictor=predictor, output_size=IMAGE_SIZE) img = aligned_image.resize([G.img_resolution, G.img_resolution]) dataset = Image2Dataset(img) dataloader = DataLoader(dataset, batch_size=1, shuffle=False) coach = SingleIDCoach(dataloader, use_wandb=False) new_G, w_pivot = coach.train() return new_G, w_pivot def export_updated_pickle(new_G, out_path, run_name): image_name = "customIMG" with open(paths_config.stylegan2_ada_ffhq, "rb") as f: old_G = pickle.load(f)["G_ema"].cuda() embedding = Path(f"{paths_config.checkpoints_dir}/model_{run_name}_{image_name}.pt") with open(embedding, "rb") as f_new: new_G = torch.load(f_new).cuda() print("Exporting large updated pickle based off new generator and ffhq.pkl") with open(paths_config.stylegan2_ada_ffhq, "rb") as f: d = pickle.load(f) old_G = d["G_ema"].cuda() # tensor old_D = d["D"].eval().requires_grad_(False).cpu() tmp = {} tmp["G"] = old_G.eval().requires_grad_(False).cpu() tmp["G_ema"] = new_G.eval().requires_grad_(False).cpu() tmp["D"] = old_D tmp["training_set_kwargs"] = None tmp["augment_pipe"] = None with open(out_path, "wb") as f: pickle.dump(tmp, f) # delete embedding.unlink() # if __name__ == '__main__': # from PIL import Image # img = Image.open('PTI/test/test.jpg') # new_G, w_pivot = run_PTI(img, use_wandb=False, use_multi_id_training=False) # out_path = f'checkpoints/stylegan2_custom_512_pytorch.pkl' # export_updated_pickle(new_G, out_path)