|
import os |
|
from random import choice |
|
from string import ascii_uppercase |
|
from PIL import Image |
|
from tqdm import tqdm |
|
from scripts.latent_editor_wrapper import LatentEditorWrapper |
|
from evaluation.experiment_setting_creator import ExperimentRunner |
|
import torch |
|
from configs import paths_config, hyperparameters, evaluation_config |
|
from utils.log_utils import save_concat_image, save_single_image |
|
from utils.models_utils import load_tuned_G |
|
|
|
|
|
class EditComparison:
    """Save reconstruction and latent-edit images for each evaluated input image.

    For every image, latents of the baseline methods listed in
    ``evaluation_config.evaluated_methods`` and the latent stored under
    ``paths_config.pti_results_keyword`` are loaded. They are edited
    (InterFaceGAN / GANSpace via ``LatentEditorWrapper``) and handed to
    ``save_concat_image`` / ``save_single_image`` together with the tuned
    generator and ``experiment_creator.old_G``.
    """

    def __init__(self, save_single_images, save_concatenated_images, run_id):
        """
        Args:
            save_single_images: if truthy, write one image per latent/edit via ``save_single_image``.
            save_concatenated_images: if truthy, write comparison images via ``save_concat_image``.
            run_id: experiment run identifier; used in output paths and to load tuned generators.
        """
        self.run_id = run_id
        self.experiment_creator = ExperimentRunner(run_id)
        self.save_single_images = save_single_images
        self.save_concatenated_images = save_concatenated_images
        self.latent_editor = LatentEditorWrapper()
        # Per-image output directories; set by create_output_dirs() before any save_* call.
        self.concat_base_dir = None
        self.single_base_dir = None

    @staticmethod
    def _image_name_from_path(image_path):
        """Return the file name of *image_path* up to its first '.'.

        The directory part is stripped *before* splitting on '.', so dots in
        parent directories (e.g. a leading './') no longer truncate the name.
        The "up to the first dot" convention for the file name itself is kept.
        """
        return os.path.basename(image_path).split('.')[0]

    def save_reconstruction_images(self, image_latents, new_inv_image_latent, new_G, target_image):
        """Save the 'rec' (unedited) images for the current image.

        Args:
            image_latents: baseline latents, one per evaluated method.
            new_inv_image_latent: latent loaded from the PTI results folder.
            new_G: tuned generator used for this image.
            target_image: PIL image of the original input.
        """
        if self.save_concatenated_images:
            save_concat_image(self.concat_base_dir, image_latents, new_inv_image_latent, new_G,
                              self.experiment_creator.old_G, 'rec', target_image)

        if self.save_single_images:
            save_single_image(self.single_base_dir, new_inv_image_latent, new_G, 'rec')
            # JPEG cannot store alpha or palette images; convert those so the save does not fail.
            original = target_image
            if original.mode in ('RGBA', 'LA', 'P'):
                original = original.convert('RGB')
            original.save(f'{self.single_base_dir}/Original.jpg')

    def create_output_dirs(self, full_image_name):
        """Create (if missing) and remember the output folders for one image.

        Sets ``self.concat_base_dir`` and ``self.single_base_dir`` under
        ``<experiments_output_dir>/<input_data_id>/<run_id>/<full_image_name>``.
        """
        output_base_dir_path = (f'{paths_config.experiments_output_dir}/{paths_config.input_data_id}/'
                                f'{self.run_id}/{full_image_name}')
        self.concat_base_dir = f'{output_base_dir_path}/concat_images'
        self.single_base_dir = f'{output_base_dir_path}/single_images'

        # makedirs also creates output_base_dir_path as a parent.
        os.makedirs(self.concat_base_dir, exist_ok=True)
        os.makedirs(self.single_base_dir, exist_ok=True)

    def get_image_latent_codes(self, image_name):
        """Load the saved latents for *image_name*.

        Returns:
            ``(image_latents, new_inv_image_latent)``. ``image_latents`` holds one
            loaded latent per entry of ``evaluation_config.evaluated_methods``, in
            that order. ``new_inv_image_latent`` is the latent stored under
            ``paths_config.pti_results_keyword``.
        """
        embedding_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'

        def latent_path(method_dir):
            # Every method stores its latent at <embedding_dir>/<method_dir>/<image_name>/0.pt.
            return f'{embedding_dir}/{method_dir}/{image_name}/0.pt'

        # 'SG2' has no folder of its own: its latent is read from the PTI results folder.
        image_latents = [
            torch.load(latent_path(paths_config.pti_results_keyword if method == 'SG2' else method))
            for method in evaluation_config.evaluated_methods
        ]
        new_inv_image_latent = torch.load(latent_path(paths_config.pti_results_keyword))

        return image_latents, new_inv_image_latent

    def save_interfacegan_edits(self, image_latents, new_inv_image_latent, interfacegan_factors, new_G, target_image):
        """Apply InterFaceGAN edits to every latent and save the results.

        The editor's result is indexed as ``edits[direction][factor]``. Images
        are named ``'<direction>_<factor>'``.
        """
        new_w_inv_edits = self.latent_editor.get_single_interface_gan_edits(new_inv_image_latent,
                                                                            interfacegan_factors)
        baseline_edits = [self.latent_editor.get_single_interface_gan_edits(latent, interfacegan_factors)
                          for latent in image_latents]

        for direction, edits_by_factor in new_w_inv_edits.items():
            for factor, edited_latent in edits_by_factor.items():
                edit_name = f'{direction}_{factor}'
                if self.save_concatenated_images:
                    save_concat_image(self.concat_base_dir,
                                      [method_edits[direction][factor] for method_edits in baseline_edits],
                                      edited_latent,
                                      new_G,
                                      self.experiment_creator.old_G,
                                      edit_name, target_image)
                if self.save_single_images:
                    save_single_image(self.single_base_dir, edited_latent, new_G, edit_name)

    def save_ganspace_edits(self, image_latents, new_inv_image_latent, factors, new_G, target_image):
        """Apply GANSpace edits to every latent and save the results.

        The editor's result is treated as a sequence of edited latents. The
        i-th entry of each method's sequence is compared against the i-th
        entry for the PTI latent, and images are named ``'ganspace_<i>'``.
        """
        new_w_inv_edits = self.latent_editor.get_single_ganspace_edits(new_inv_image_latent, factors)
        baseline_edits = [self.latent_editor.get_single_ganspace_edits(latent, factors)
                          for latent in image_latents]

        for idx, edited_latent in enumerate(new_w_inv_edits):
            edit_name = f'ganspace_{idx}'
            if self.save_concatenated_images:
                save_concat_image(self.concat_base_dir, [method_edits[idx] for method_edits in baseline_edits],
                                  edited_latent,
                                  new_G,
                                  self.experiment_creator.old_G,
                                  edit_name, target_image)
            if self.save_single_images:
                save_single_image(self.single_base_dir, edited_latent, new_G, edit_name)

    def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
        """Run the underlying experiment, then save reconstructions and edits per image.

        At most ``hyperparameters.max_images_to_invert`` images are processed.
        """
        # -3.0 .. 3.0 in steps of 0.5, skipping 0 (no-op edit).
        interfacegan_factors = [val / 2 for val in range(-6, 7) if val != 0]
        # -20, -15, ..., 20.
        ganspace_factors = range(-20, 25, 5)

        self.experiment_creator.run_experiment(run_pt, create_other_latents, use_multi_id_training, use_wandb)

        # With multi-id training a single generator is loaded once; otherwise one is loaded per image below.
        new_G = load_tuned_G(self.run_id, paths_config.multi_id_model_type) if use_multi_id_training else None

        images_paths = self.experiment_creator.images_paths
        for idx, image_path in tqdm(enumerate(images_paths), total=len(images_paths)):
            if idx >= hyperparameters.max_images_to_invert:
                break

            image_name = self._image_name_from_path(image_path)

            # NOTE(review): assumes target_paths[idx] pairs with images_paths[idx] -- confirm both
            # lists are built in the same order by ExperimentRunner.
            # The context manager closes the image even if a save/edit step raises.
            with Image.open(self.experiment_creator.target_paths[idx]) as target_image:
                if not use_multi_id_training:
                    new_G = load_tuned_G(self.run_id, image_name)

                image_latents, new_inv_image_latent = self.get_image_latent_codes(image_name)
                self.create_output_dirs(image_name)

                self.save_reconstruction_images(image_latents, new_inv_image_latent, new_G, target_image)
                self.save_interfacegan_edits(image_latents, new_inv_image_latent, interfacegan_factors, new_G,
                                             target_image)
                self.save_ganspace_edits(image_latents, new_inv_image_latent, ganspace_factors, new_G, target_image)

            torch.cuda.empty_cache()
|
|
|
|
|
def run_pti_and_full_edit(iid):
    """Run the edit comparison against SG2Plus, e4e and SG2 latents.

    Sets ``evaluation_config.evaluated_methods`` globally, then runs
    ``EditComparison.run_experiment`` with ``run_pt`` and
    ``create_other_latents`` enabled. Single and concatenated images are both
    saved. ``iid`` is appended to the run id.
    """
    evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
    run_id = f'{paths_config.input_data_id}_pti_full_edit_{iid}'
    comparison = EditComparison(save_single_images=True, save_concatenated_images=True, run_id=run_id)
    comparison.run_experiment(run_pt=True, create_other_latents=True,
                              use_multi_id_training=False, use_wandb=False)
|
|
|
|
|
def pti_no_comparison(iid):
    """Run the PTI edit pipeline with no baseline methods to compare against.

    Clears ``evaluation_config.evaluated_methods`` globally, then runs
    ``EditComparison.run_experiment`` with ``run_pt`` enabled and
    ``create_other_latents`` disabled. ``iid`` is appended to the run id.
    """
    evaluation_config.evaluated_methods = []
    run_id = f'{paths_config.input_data_id}_pti_no_comparison_{iid}'
    comparison = EditComparison(save_single_images=True, save_concatenated_images=True, run_id=run_id)
    comparison.run_experiment(run_pt=True, create_other_latents=False,
                              use_multi_id_training=False, use_wandb=False)
|
|
|
|
|
def edits_for_existed_experiment(run_id):
    """Save edit comparisons for an existing run, without re-running PTI.

    Sets ``evaluation_config.evaluated_methods`` to SG2Plus, e4e and SG2
    globally, then runs ``EditComparison.run_experiment`` with ``run_pt``
    disabled and ``create_other_latents`` enabled under the given ``run_id``.
    """
    evaluation_config.evaluated_methods = ['SG2Plus', 'e4e', 'SG2']
    comparison = EditComparison(save_single_images=True, save_concatenated_images=True, run_id=run_id)
    comparison.run_experiment(run_pt=False, create_other_latents=True,
                              use_multi_id_training=False, use_wandb=False)
|
|
|
|
|
if __name__ == '__main__':
    # Random 7-letter uppercase suffix so separate invocations get separate run ids.
    run_suffix = ''.join([choice(ascii_uppercase) for _ in range(7)])
    pti_no_comparison(run_suffix)
|
|