# Source: PTI / scripts/latent_editor_wrapper.py
# (Hosting-page header, kept as a comment: author ucalyptus, commit 2d7efb8 "simp".)
import torch
from configs import paths_config
from editings.latent_editor import LatentEditor
class LatentEditorWrapper:
    """Convenience front-end over ``LatentEditor`` with preset edit directions.

    On construction it:
      * loads three InterFaceGAN direction tensors ('age', 'smile', 'rotation')
        from the paths in ``paths_config`` and moves each one to the GPU with
        ``.cuda()``;
      * loads the GANSpace PCA object from ``paths_config.ffhq_pca``;
      * registers a few named GANSpace preset tuples;
      * creates a ``LatentEditor`` that performs the actual edits.

    The ``get_single_*`` methods sweep a sequence of factors over those presets
    and return whatever ``LatentEditor.apply_*`` produces.

    NOTE(review): the unconditional ``.cuda()`` means construction fails on a
    machine without CUDA.
    """

    def __init__(self):
        # Direction file paths. Each is formatted through an f-string (i.e. to
        # a str) before being loaded.
        self.interfacegan_directions = dict(
            age=f'{paths_config.interfacegan_age}',
            smile=f'{paths_config.interfacegan_smile}',
            rotation=f'{paths_config.interfacegan_rotation}',
        )

        # Load every direction tensor, in the order above, and move it to the GPU.
        loaded_tensors = {}
        for key, file_path in self.interfacegan_directions.items():
            loaded_tensors[key] = torch.load(file_path).cuda()
        self.interfacegan_directions_tensors = loaded_tensors

        self.ganspace_pca = torch.load(f'{paths_config.ffhq_pca}')

        # Named GANSpace presets. get_single_ganspace_edits keeps every element
        # except the last and substitutes the caller's factor for it. The
        # leading entries are passed to LatentEditor.apply_ganspace unchanged
        # (presumably a PCA component index and a layer range -- TODO confirm
        # against editings.latent_editor).
        # NOTE(review): an earlier comment here pointed to a source for more
        # edit directions, but its link was cut off ("please visit ..").
        self.ganspace_directions = {
            'eye_openness': (54, 7, 8, 5),
            'smile': (46, 4, 5, -6),
            'trimmed_beard': (58, 7, 9, 7),
        }

        self.latent_editor = LatentEditor()

    @staticmethod
    def _replace_last(values, new_last):
        """Return ``values`` as a tuple whose final element is ``new_last``."""
        patched = list(values)
        patched[-1] = new_last
        return tuple(patched)

    def get_single_ganspace_edits(self, start_w, factors):
        """Apply each GANSpace preset on its own, once per factor.

        Args:
            start_w: latent passed straight to ``LatentEditor.apply_ganspace``.
            factors: values substituted, one at a time, for the last element
                of every preset tuple.

        Returns:
            A flat list of ``apply_ganspace`` results in preset-major order:
            all factors for the first preset, then all factors for the next
            preset, and so on. Each call receives a one-element list holding
            the patched preset tuple.
        """
        return [
            self.latent_editor.apply_ganspace(
                start_w, self.ganspace_pca, [self._replace_last(preset, factor)])
            for preset in self.ganspace_directions.values()
            for factor in factors
        ]

    def get_single_interface_gan_edits(self, start_w, factors):
        """Apply each InterFaceGAN direction on its own, once per factor.

        Directions are visited in the fixed order 'rotation', 'smile', 'age'.
        Each factor is halved before it is passed to
        ``LatentEditor.apply_interfacegan``, but the result is stored under
        the original, unhalved factor.
        NOTE(review): the reason for the ``/ 2`` is not documented here.

        Returns:
            ``{direction: {factor: apply_interfacegan result}}``. A direction
            gets an entry only if at least one factor was applied to it, so an
            empty ``factors`` yields ``{}``.
        """
        edits_by_direction = {}
        for name in ('rotation', 'smile', 'age'):
            for factor in factors:
                by_factor = edits_by_direction.setdefault(name, {})
                by_factor[factor] = self.latent_editor.apply_interfacegan(
                    start_w, self.interfacegan_directions_tensors[name], factor / 2)
        return edits_by_direction