# PTI/edit.py
import glob
import os
import pickle
import sys

import click
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb
from IPython.display import display
from PIL import Image

from configs import paths_config, hyperparameters, global_config
from scripts.latent_editor_wrapper import LatentEditorWrapper
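
# Editing configuration: machine-specific paths to the e4e encoder, the FFHQ
# StyleGAN2-ADA generator pickle, and the directory holding the PTI-tuned
# checkpoints; model_id identifies the PTI run whose checkpoints are edited.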
image_dir_name = '/home/sayantan/processed_images'
use_multi_id_training = False
global_config.device = 'cuda'
paths_config.e4e = '/home/sayantan/PTI/pretrained_models/e4e_ffhq_encode.pt'
paths_config.input_data_id = image_dir_name
paths_config.input_data_path = image_dir_name
paths_config.stylegan2_ada_ffhq = '/home/sayantan/PTI/pretrained_models/ffhq.pkl'
paths_config.checkpoints_dir = '/home/sayantan/PTI/'
paths_config.style_clip_pretrained_mappers = '/home/sayantan/PTI/pretrained_models'
hyperparameters.use_locality_regularization = False
hyperparameters.lpips_type = 'squeeze'
model_id = "MYJJDFVGATAT"
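

# Concatenate a list of PIL images horizontally so they can be viewed side by side.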
def display_alongside_source_image(images):
res = np.concatenate([np.array(image) for image in images], axis=1)
return Image.fromarray(res)
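

# Load the original FFHQ generator (old_G) from the StyleGAN2-ADA pickle and the
# PTI-tuned generator (new_G) saved as model_{model_id}_{image_name}.pt in the
# checkpoints directory.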
def load_generators(model_id, image_name):
with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
old_G = pickle.load(f)['G_ema'].cuda()
with open(f'{paths_config.checkpoints_dir}/model_{model_id}_{image_name}.pt', 'rb') as f_new:
new_G = torch.load(f_new).cuda()
return old_G, new_G
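

# Convert synthesized image tensors (NCHW in [-1, 1]) to 256x256 PIL images and display them.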
def plot_syn_images(syn_images, text):
    for img in syn_images:
        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0]
        plt.axis('off')
        resized_image = Image.fromarray(img, mode='RGB').resize((256, 256))
        display(resized_image)
        # wandb.log({text: [wandb.Image(resized_image, caption="Label")]})
        del img
        del resized_image
        torch.cuda.empty_cache()
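

# Same tensor-to-PIL conversion as plot_syn_images, but returns the image instead of displaying it.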
def syn_images_wandb(img):
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0]
    plt.axis('off')
    resized_image = Image.fromarray(img, mode='RGB').resize((256, 256))
    return resized_image
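

# Edit a single image: load both generators, reconstruct the image from its pivot
# latent (0.pt under the embedding directory), then apply each edit direction
# returned by LatentEditorWrapper.get_single_interface_gan_edits with factors in
# range(-5, 5) and save the results as JPEGs.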
def edit(image_name):
generator_type = paths_config.multi_id_model_type if use_multi_id_training else image_name
old_G, new_G = load_generators(model_id, generator_type)
w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
w_pivot = torch.load(f'{embedding_dir}/0.pt')
    old_image = old_G.synthesis(w_pivot, noise_mode='const', force_fp32=True)
    new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32=True)
    latent_editor = LatentEditorWrapper()
    latents_after_edit = latent_editor.get_single_interface_gan_edits(w_pivot, list(range(-5, 5)))
    for direction, factor_and_edit in latents_after_edit.items():
        for editkey in factor_and_edit.keys():
            new_image = new_G.synthesis(factor_and_edit[editkey], noise_mode='const', force_fp32=True)
            # Create the output directory before saving (Image.save does not create it).
            os.makedirs(f"/home/sayantan/PTI/{direction}/{editkey}", exist_ok=True)
            syn_images_wandb(new_image).save(f"/home/sayantan/PTI/{direction}/{editkey}/{image_name}.jpg")
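

# Assumes the script is run from the checkpoints directory: image names are
# recovered from tuned-generator files named model_{model_id}_{image_name}.pt.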
if __name__ == '__main__':
for image_name in [f.split(".")[0].split("_")[2] for f in sorted(glob.glob("*.pt"))]:
edit(image_name)