import os
from random import choice
from string import ascii_uppercase
from PIL import Image
from tqdm import tqdm
from scripts.latent_editor_wrapper import LatentEditorWrapper
from evaluation.experiment_setting_creator import ExperimentRunner
import torch
from configs import paths_config, hyperparameters, evaluation_config
from utils.log_utils import save_concat_image, save_single_image
from utils.models_utils import load_tuned_G


def _image_name_from_path(image_path):
    """Return the file name of *image_path* up to (not including) its first dot.

    The directory part is stripped *before* cutting at the dot, so paths such
    as ``./images/face.jpg`` or ``/data/v1.0/face.jpg`` yield ``face`` instead
    of an empty or truncated directory fragment.
    """
    return os.path.basename(image_path).split('.')[0]


def _load_latent(results_dir_name, image_name):
    """Load ``0.pt`` for *image_name* from the given results sub-directory."""
    return torch.load(
        f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}/'
        f'{results_dir_name}/{image_name}/0.pt')


class EditComparison:
    """Runs an experiment and saves reconstruction and edit images per input image.

    Output is written below
    ``{experiments_output_dir}/{input_data_id}/{run_id}/{image_name}`` in the
    ``concat_images`` and ``single_images`` sub-directories.
    """

    def __init__(self, save_single_images, save_concatenated_images, run_id):
        """Store the output switches and build the experiment runner and editor.

        Args:
            save_single_images: When true, images are written with
                ``save_single_image`` into the ``single_images`` directory.
            save_concatenated_images: When true, images are written with
                ``save_concat_image`` into the ``concat_images`` directory.
            run_id: Identifier used for the output directory and passed to
                ``ExperimentRunner`` and ``load_tuned_G``.
        """
        self.run_id = run_id
        self.experiment_creator = ExperimentRunner(run_id)
        self.save_single_images = save_single_images
        self.save_concatenated_images = save_concatenated_images
        self.latent_editor = LatentEditorWrapper()

    def save_reconstruction_images(self, image_latents, new_inv_image_latent, new_G, target_image):
        """Save the un-edited ('rec') images and a copy of the target image.

        ``create_output_dirs`` must have been called first, since it sets the
        ``concat_base_dir`` / ``single_base_dir`` attributes used here.

        Args:
            image_latents: Latents of the compared methods, forwarded to
                ``save_concat_image``.
            new_inv_image_latent: Latent rendered with ``new_G``.
            new_G: Generator returned by ``load_tuned_G``.
            target_image: Opened PIL image; saved as ``Original.jpg`` when
                single images are enabled.
        """
        if self.save_concatenated_images:
            save_concat_image(self.concat_base_dir, image_latents, new_inv_image_latent, new_G,
                              self.experiment_creator.old_G, 'rec', target_image)

        if self.save_single_images:
            save_single_image(self.single_base_dir, new_inv_image_latent, new_G, 'rec')
            target_image.save(f'{self.single_base_dir}/Original.jpg')

    def create_output_dirs(self, full_image_name):
        """Create the output directories for one image and remember their paths.

        Sets ``self.concat_base_dir`` and ``self.single_base_dir``.
        """
        output_base_dir_path = (
            f'{paths_config.experiments_output_dir}/{paths_config.input_data_id}/'
            f'{self.run_id}/{full_image_name}'
        )
        os.makedirs(output_base_dir_path, exist_ok=True)

        self.concat_base_dir = f'{output_base_dir_path}/concat_images'
        self.single_base_dir = f'{output_base_dir_path}/single_images'

        os.makedirs(self.concat_base_dir, exist_ok=True)
        os.makedirs(self.single_base_dir, exist_ok=True)

    def get_image_latent_codes(self, image_name):
        """Load the stored latents for *image_name*.

        For every entry of ``evaluation_config.evaluated_methods`` one latent
        is loaded; the ``'SG2'`` entry is read from the
        ``paths_config.pti_results_keyword`` directory, every other entry from
        a directory named after the method itself.

        Returns:
            A tuple ``(image_latents, new_inv_image_latent)`` where
            ``image_latents`` is the per-method list and
            ``new_inv_image_latent`` is the latent stored under
            ``paths_config.pti_results_keyword``.
        """
        image_latents = [
            _load_latent(paths_config.pti_results_keyword if method == 'SG2' else method, image_name)
            for method in evaluation_config.evaluated_methods
        ]
        new_inv_image_latent = _load_latent(paths_config.pti_results_keyword, image_name)

        return image_latents, new_inv_image_latent

    def save_interfacegan_edits(self, image_latents, new_inv_image_latent, interfacegan_factors,
                                new_G, target_image):
        """Save one image per (direction, factor) edit returned by the latent editor.

        The edits of ``new_inv_image_latent`` drive the iteration; the edits of
        each latent in ``image_latents`` are looked up with the same
        ``[direction][factor]`` keys for the concatenated image.
        """
        new_w_inv_edits = self.latent_editor.get_single_interface_gan_edits(
            new_inv_image_latent, interfacegan_factors)

        inv_edits = [
            self.latent_editor.get_single_interface_gan_edits(latent, interfacegan_factors)
            for latent in image_latents
        ]

        for direction, factor_edits in new_w_inv_edits.items():
            for factor, edit_tensor in factor_edits.items():
                edit_name = f'{direction}_{factor}'

                if self.save_concatenated_images:
                    save_concat_image(self.concat_base_dir,
                                      [method_edits[direction][factor] for method_edits in inv_edits],
                                      edit_tensor, new_G,
                                      self.experiment_creator.old_G,
                                      edit_name, target_image)

                if self.save_single_images:
                    save_single_image(self.single_base_dir, edit_tensor, new_G, edit_name)

    def save_ganspace_edits(self, image_latents, new_inv_image_latent, factors, new_G, target_image):
        """Save one image per edit returned by ``get_single_ganspace_edits``.

        Images are named ``ganspace_{idx}`` where ``idx`` is the position of
        the edit in the returned sequence.
        """
        new_w_inv_edits = self.latent_editor.get_single_ganspace_edits(new_inv_image_latent, factors)

        inv_edits = [
            self.latent_editor.get_single_ganspace_edits(latent, factors)
            for latent in image_latents
        ]

        # Positional access is kept on purpose: only ``len()`` and ``[]`` are
        # known to be supported by the editor's return value from this file.
        for idx in range(len(new_w_inv_edits)):
            edit_name = f'ganspace_{idx}'

            if self.save_concatenated_images:
                save_concat_image(self.concat_base_dir,
                                  [method_edits[idx] for method_edits in inv_edits],
                                  new_w_inv_edits[idx], new_G,
                                  self.experiment_creator.old_G,
                                  edit_name, target_image)

            if self.save_single_images:
                save_single_image(self.single_base_dir, new_w_inv_edits[idx], new_G, edit_name)

    def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
        """Run the underlying experiment, then save all images for each input.

        Args:
            run_pt: Forwarded unchanged to ``ExperimentRunner.run_experiment``.
            create_other_latents: Forwarded unchanged to
                ``ExperimentRunner.run_experiment``.
            use_multi_id_training: When true, a single generator is loaded once
                with ``paths_config.multi_id_model_type``; otherwise a
                generator is loaded per image using the image name.
            use_wandb: Forwarded unchanged to
                ``ExperimentRunner.run_experiment``.

        At most ``hyperparameters.max_images_to_invert`` images are processed.
        """
        new_G = None
        # -3.0 .. 3.0 in steps of 0.5, skipping 0.
        interfacegan_factors = [val / 2 for val in range(-6, 7) if val != 0]
        # -20 .. 20 in steps of 5.
        ganspace_factors = range(-20, 25, 5)

        self.experiment_creator.run_experiment(run_pt, create_other_latents,
                                               use_multi_id_training, use_wandb)

        if use_multi_id_training:
            new_G = load_tuned_G(self.run_id, paths_config.multi_id_model_type)

        images_paths = self.experiment_creator.images_paths
        for idx, image_path in tqdm(enumerate(images_paths), total=len(images_paths)):

            if idx >= hyperparameters.max_images_to_invert:
                break

            image_name = _image_name_from_path(image_path)

            # The context manager closes the image file even when loading or
            # saving below raises.
            with Image.open(self.experiment_creator.target_paths[idx]) as target_image:

                if not use_multi_id_training:
                    new_G = load_tuned_G(self.run_id, image_name)

                image_latents, new_inv_image_latent = self.get_image_latent_codes(image_name)

                self.create_output_dirs(image_name)

                self.save_reconstruction_images(image_latents, new_inv_image_latent, new_G, target_image)

                self.save_interfacegan_edits(image_latents, new_inv_image_latent, interfacegan_factors,
                                             new_G, target_image)

                self.save_ganspace_edits(image_latents, new_inv_image_latent, ganspace_factors,
                                         new_G, target_image)

            torch.cuda.empty_cache()


def run_pti_and_full_edit(iid):
    """Run with ``run_pt=True`` and ``create_other_latents=True``, comparing three methods.

    Sets ``evaluation_config.evaluated_methods`` to ``['SG2Plus', 'e4e', 'SG2']``
    (a module-level setting that stays changed after the call).
    """
    evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
    edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True,
                                         run_id=f'{paths_config.input_data_id}_pti_full_edit_{iid}')
    edit_figure_creator.run_experiment(True, True, use_multi_id_training=False, use_wandb=False)


def pti_no_comparison(iid):
    """Run with ``run_pt=True`` and ``create_other_latents=False``, with no compared methods.

    Sets ``evaluation_config.evaluated_methods`` to an empty list
    (a module-level setting that stays changed after the call).
    """
    evaluation_config.evaluated_methods = []
    edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True,
                                         run_id=f'{paths_config.input_data_id}_pti_no_comparison_{iid}')
    edit_figure_creator.run_experiment(True, False, use_multi_id_training=False, use_wandb=False)


def edits_for_existed_experiment(run_id):
    """Run with ``run_pt=False`` and ``create_other_latents=True`` for an existing *run_id*.

    Sets ``evaluation_config.evaluated_methods`` to ``['SG2Plus', 'e4e', 'SG2']``
    (a module-level setting that stays changed after the call).
    """
    evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
    edit_figure_creator = EditComparison(save_single_images=True, save_concatenated_images=True,
                                         run_id=run_id)
    edit_figure_creator.run_experiment(False, True, use_multi_id_training=False, use_wandb=False)


if __name__ == '__main__':
    # Random 7-letter upper-case suffix so that repeated runs get distinct run ids.
    iid = ''.join(choice(ascii_uppercase) for _ in range(7))
    pti_no_comparison(iid)