# PTI / tune.py
# Source: Hugging Face file view — last commit 454970e ("Update tune.py") by ucalyptus; 1.87 kB.
import click
import os
import sys
import pickle
import numpy as np
from PIL import Image
import torch
from configs import paths_config, hyperparameters, global_config
from IPython.display import display
import matplotlib.pyplot as plt
from scripts.latent_editor_wrapper import LatentEditorWrapper
# ---------------------------------------------------------------------------
# Run configuration.
# The assignments below mutate the shared project config modules in place, so
# they must run before anything reads those values.
# ---------------------------------------------------------------------------
image_dir_name = 'images'        # presumably the folder holding the input image(s)
use_multi_id_training = False    # module-level flag for single- vs multi-ID PTI

# Compute device used by the project.
global_config.device = 'cuda'

# File/directory locations (relative paths, resolved against the working dir).
paths_config.e4e = 'e4e_ffhq_encode.pt'
paths_config.input_data_id = image_dir_name
paths_config.input_data_path = image_dir_name
paths_config.stylegan2_ada_ffhq = 'ffhq.pkl'
paths_config.checkpoints_dir = 'checkpoints'
paths_config.style_clip_pretrained_mappers = ''

# Tuning hyperparameters.
hyperparameters.use_locality_regularization = False
hyperparameters.lpips_type = 'squeeze'

# Deliberately imported only after the config above is set — presumably because
# scripts.run_pti (or something it imports) reads these values at import time.
# TODO confirm before moving this to the top-of-file import block.
from scripts.run_pti import run_PTI
def load_generator(model_id, device='cuda'):
    """Load a PTI-tuned generator checkpoint and place it on *device*.

    Args:
        model_id: Identifier of the tuning run (the value ``run_PTI`` returns);
            selects ``{paths_config.checkpoints_dir}/model_{model_id}_file.pt``.
        device: Target torch device. Defaults to ``'cuda'``, matching the
            previously hard-coded ``.cuda()`` call.

    Returns:
        The object stored in the checkpoint, moved to *device*. Callers use it
        as a generator exposing ``synthesis`` — TODO confirm checkpoint format.
    """
    checkpoint_path = f'{paths_config.checkpoints_dir}/model_{model_id}_file.pt'
    with open(checkpoint_path, 'rb') as f_new:
        # map_location puts tensors straight onto `device`, so a checkpoint that
        # was saved on a GPU can still be opened when a different device is asked for.
        # SECURITY NOTE: torch.load unpickles arbitrary objects — only load
        # checkpoints you produced yourself.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
        # whole pickled modules; pass weights_only=False there if loading fails.
        new_G = torch.load(f_new, map_location=device)
    return new_G.to(device)
def tensor_to_pil(img, size=(256, 256)):
    """Convert the first image of a generator output batch to a resized PIL image.

    Args:
        img: Image tensor indexed as (batch, channel, height, width), as the
            permute below assumes. Values are presumably in [-1, 1] (StyleGAN
            output range) — TODO confirm for other generators.
        size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.
            Defaults to ``(256, 256)``, the previously hard-coded size.

    Returns:
        An RGB ``PIL.Image.Image`` of the requested size.
    """
    # Slice out the only sample we keep *before* moving data off the device,
    # instead of copying the whole batch to the CPU and discarding the rest.
    first = img[0].detach()
    # Scale to [0, 255]. Adding 128 (not 127.5) is intentional: after clamping,
    # values are non-negative and the uint8 cast truncates, so this rounds to nearest.
    pixels = (first.permute(1, 2, 0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    array = pixels.cpu().numpy()
    return Image.fromarray(array, mode='RGB').resize(size)
def tune(output_path='output/out.png'):
    """Run PTI on the configured input and save a render from the tuned generator.

    Steps:
        1. Invert and fine-tune via ``run_PTI``.
        2. Load the saved pivot latent and the tuned generator.
        3. Synthesize the image and write it to *output_path*.

    Args:
        output_path: Destination PNG path. Defaults to the previously
            hard-coded ``output/out.png``. Missing parent directories are
            created; previously, saving failed when ``output/`` did not exist.
    """
    model_id = run_PTI(run_name='', use_wandb=False,
                       use_multi_id_training=use_multi_id_training)

    w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
    embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/file'
    # Pivot latent saved as 0.pt (presumably the first/only input image —
    # TODO confirm). Map it to the configured device so it matches the generator.
    w_pivot = torch.load(f'{embedding_dir}/0.pt', map_location=global_config.device)

    new_G = load_generator(model_id)
    # Inference only: skip autograd graph construction to save memory.
    with torch.no_grad():
        new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32=True)

    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    tensor_to_pil(new_image).save(output_path)
# ----------------------------------------------------------------------------
# Script entry point: run the full tune-and-render pipeline when this file is
# executed directly (importing the module does not trigger it).
if __name__ == "__main__":
    tune()

# ----------------------------------------------------------------------------