PTI / evaluation /qualitative_edit_comparison.py
ucalyptus's picture
simp
2d7efb8
import os
from random import choice
from string import ascii_uppercase
from PIL import Image
from tqdm import tqdm
from scripts.latent_editor_wrapper import LatentEditorWrapper
from evaluation.experiment_setting_creator import ExperimentRunner
import torch
from configs import paths_config, hyperparameters, evaluation_config
from utils.log_utils import save_concat_image, save_single_image
from utils.models_utils import load_tuned_G
class EditComparison:
def __init__(self, save_single_images, save_concatenated_images, run_id):
self.run_id = run_id
self.experiment_creator = ExperimentRunner(run_id)
self.save_single_images = save_single_images
self.save_concatenated_images = save_concatenated_images
self.latent_editor = LatentEditorWrapper()
def save_reconstruction_images(self, image_latents, new_inv_image_latent, new_G, target_image):
if self.save_concatenated_images:
save_concat_image(self.concat_base_dir, image_latents, new_inv_image_latent, new_G,
self.experiment_creator.old_G,
'rec',
target_image)
if self.save_single_images:
save_single_image(self.single_base_dir, new_inv_image_latent, new_G, 'rec')
target_image.save(f'{self.single_base_dir}/Original.jpg')
def create_output_dirs(self, full_image_name):
output_base_dir_path = f'{paths_config.experiments_output_dir}/{paths_config.input_data_id}/{self.run_id}/{full_image_name}'
os.makedirs(output_base_dir_path, exist_ok=True)
self.concat_base_dir = f'{output_base_dir_path}/concat_images'
self.single_base_dir = f'{output_base_dir_path}/single_images'
os.makedirs(self.concat_base_dir, exist_ok=True)
os.makedirs(self.single_base_dir, exist_ok=True)
def get_image_latent_codes(self, image_name):
image_latents = []
for method in evaluation_config.evaluated_methods:
if method == 'SG2':
image_latents.append(torch.load(
f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/'
f'{paths_config.pti_results_keyword}/{image_name}/0.pt'))
else:
image_latents.append(torch.load(
f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{method}/{image_name}/0.pt'))
new_inv_image_latent = torch.load(
f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/{paths_config.pti_results_keyword}/{image_name}/0.pt')
return image_latents, new_inv_image_latent
def save_interfacegan_edits(self, image_latents, new_inv_image_latent, interfacegan_factors, new_G, target_image):
new_w_inv_edits = self.latent_editor.get_single_interface_gan_edits(new_inv_image_latent,
interfacegan_factors)
inv_edits = []
for latent in image_latents:
inv_edits.append(self.latent_editor.get_single_interface_gan_edits(latent, interfacegan_factors))
for direction, edits in new_w_inv_edits.items():
for factor, edit_tensor in edits.items():
if self.save_concatenated_images:
save_concat_image(self.concat_base_dir, [edits[direction][factor] for edits in inv_edits],
new_w_inv_edits[direction][factor],
new_G,
self.experiment_creator.old_G,
f'{direction}_{factor}', target_image)
if self.save_single_images:
save_single_image(self.single_base_dir, new_w_inv_edits[direction][factor], new_G,
f'{direction}_{factor}')
def save_ganspace_edits(self, image_latents, new_inv_image_latent, factors, new_G, target_image):
new_w_inv_edits = self.latent_editor.get_single_ganspace_edits(new_inv_image_latent, factors)
inv_edits = []
for latent in image_latents:
inv_edits.append(self.latent_editor.get_single_ganspace_edits(latent, factors))
for idx in range(len(new_w_inv_edits)):
if self.save_concatenated_images:
save_concat_image(self.concat_base_dir, [edit[idx] for edit in inv_edits], new_w_inv_edits[idx],
new_G,
self.experiment_creator.old_G,
f'ganspace_{idx}', target_image)
if self.save_single_images:
save_single_image(self.single_base_dir, new_w_inv_edits[idx], new_G,
f'ganspace_{idx}')
def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
images_counter = 0
new_G = None
interfacegan_factors = [val / 2 for val in range(-6, 7) if val != 0]
ganspace_factors = range(-20, 25, 5)
self.experiment_creator.run_experiment(run_pt, create_other_latents, use_multi_id_training, use_wandb)
if use_multi_id_training:
new_G = load_tuned_G(self.run_id, paths_config.multi_id_model_type)
for idx, image_path in tqdm(enumerate(self.experiment_creator.images_paths),
total=len(self.experiment_creator.images_paths)):
if images_counter >= hyperparameters.max_images_to_invert:
break
image_name = image_path.split('.')[0].split('/')[-1]
target_image = Image.open(self.experiment_creator.target_paths[idx])
if not use_multi_id_training:
new_G = load_tuned_G(self.run_id, image_name)
image_latents, new_inv_image_latent = self.get_image_latent_codes(image_name)
self.create_output_dirs(image_name)
self.save_reconstruction_images(image_latents, new_inv_image_latent, new_G, target_image)
self.save_interfacegan_edits(image_latents, new_inv_image_latent, interfacegan_factors, new_G, target_image)
self.save_ganspace_edits(image_latents, new_inv_image_latent, ganspace_factors, new_G, target_image)
target_image.close()
torch.cuda.empty_cache()
images_counter += 1
def run_pti_and_full_edit(iid):
evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True,
run_id=f'{paths_config.input_data_id}_pti_full_edit_{iid}')
edit_figure_creator.run_experiment(True, True, use_multi_id_training=False, use_wandb=False)
def pti_no_comparison(iid):
evaluation_config.evaluated_methods = []
edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True,
run_id=f'{paths_config.input_data_id}_pti_no_comparison_{iid}')
edit_figure_creator.run_experiment(True, False, use_multi_id_training=False, use_wandb=False)
def edits_for_existed_experiment(run_id):
evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True,
run_id=run_id)
edit_figure_creator.run_experiment(False, True, use_multi_id_training=False, use_wandb=False)
if __name__ == '__main__':
iid = ''.join(choice(ascii_uppercase) for i in range(7))
pti_no_comparison(iid)