File size: 1,988 Bytes
2d7efb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch
from configs import paths_config
from editings.latent_editor import LatentEditor


class LatentEditorWrapper:
    """Produce batches of edited latent codes for display.

    Loads InterFaceGAN direction tensors and a GANSpace PCA object from the
    paths in ``paths_config`` and delegates every individual edit to a
    ``LatentEditor`` instance.
    """

    # InterFaceGAN directions edited by ``get_single_interface_gan_edits``,
    # in the order their keys appear in the returned dict.
    _INTERFACEGAN_EDIT_ORDER = ('rotation', 'smile', 'age')

    def __init__(self, device='cuda'):
        """Load the edit directions and build the underlying ``LatentEditor``.

        Args:
            device: Device the InterFaceGAN direction tensors are loaded onto.
                Defaults to ``'cuda'`` (the previous hard-coded behaviour);
                pass ``'cpu'`` to use the wrapper without a GPU.
        """
        self.interfacegan_directions = {
            'age': str(paths_config.interfacegan_age),
            'smile': str(paths_config.interfacegan_smile),
            'rotation': str(paths_config.interfacegan_rotation),
        }
        # map_location remaps storages while loading, so tensors that were
        # saved from a GPU can still be loaded when ``device`` is 'cpu'.
        self.interfacegan_directions_tensors = {
            name: torch.load(path, map_location=device)
            for name, path in self.interfacegan_directions.items()
        }
        # Loaded exactly as stored; device handling is left to
        # LatentEditor.apply_ganspace, which is not visible here.
        # NOTE(review): confirm it copes with a CPU-only ``device``.
        self.ganspace_pca = torch.load(str(paths_config.ffhq_pca))

        # GANSpace edit directions. The last element of each tuple is the
        # edit strength: get_single_ganspace_edits overwrites it with each
        # requested factor. The first three elements are presumably
        # (pca_component, start_layer, end_layer) -- TODO confirm against
        # LatentEditor.apply_ganspace.
        self.ganspace_directions = {
            'eye_openness': (54, 7, 8, 5),
            'smile': (46, 4, 5, -6),
            'trimmed_beard': (58, 7, 9, 7),
        }

        self.latent_editor = LatentEditor()

    def get_single_ganspace_edits(self, start_w, factors):
        """Apply each GANSpace direction on its own at every strength.

        Args:
            start_w: Latent code passed unchanged to ``apply_ganspace``.
            factors: Iterable of strengths. Each one replaces the last element
                of every entry in ``self.ganspace_directions``. It may be a
                one-shot iterator; it is materialized once up front.

        Returns:
            A flat list of edited latents, direction-major: all factors for
            the first direction, then all factors for the next, and so on.
        """
        # Materialize once: the inner loop runs once per direction, and a
        # generator would otherwise be exhausted after the first direction.
        factors = list(factors)
        latents_to_display = []
        for ganspace_direction in self.ganspace_directions.values():
            for factor in factors:
                # Keep the direction's leading fields; swap in this strength.
                edit_direction = (*ganspace_direction[:-1], factor)
                latents_to_display.append(
                    self.latent_editor.apply_ganspace(
                        start_w, self.ganspace_pca, [edit_direction]))
        return latents_to_display

    def get_single_interface_gan_edits(self, start_w, factors):
        """Apply the rotation, smile and age InterFaceGAN edits at each factor.

        Args:
            start_w: Latent code passed unchanged to ``apply_interfacegan``.
            factors: Iterable of edit factors. It may be a one-shot iterator;
                it is materialized once up front.

        Returns:
            A nested dict ``{direction: {factor: edited_latent}}`` with
            directions in the order rotation, smile, age. An empty ``factors``
            yields an empty dict.
        """
        # Materialize once so every direction sees all factors.
        factors = list(factors)
        latents_to_display = {}
        for direction in self._INTERFACEGAN_EDIT_ORDER:
            direction_tensor = self.interfacegan_directions_tensors[direction]
            for factor in factors:
                # The factor is halved before being applied.
                # NOTE(review): the reason for the /2 scaling is not visible here.
                latents_to_display.setdefault(direction, {})[factor] = (
                    self.latent_editor.apply_interfacegan(
                        start_w, direction_tensor, factor / 2))
        return latents_to_display