File size: 3,964 Bytes
b6068b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import os
import numpy as np
import imgui
import dnnlib
import torch
from gui_utils import imgui_utils

#----------------------------------------------------------------------------

class LatentWidget:
    """GUI widget holding the latent-code controls of the visualizer.

    Each frame it (optionally) draws a seed input, a latent-file path box
    with a 'Load latent' button, a step-size input, a 'Reset' button and the
    'w' / 'w+' checkboxes, and then publishes the current settings on
    ``viz.args``.
    """

    def __init__(self, viz):
        """Remember the owning visualizer and set the initial widget state.

        Args:
            viz: Visualizer object. This class reads ``label_w``,
                ``font_size``, ``button_w``, ``result`` and ``frame_delta``
                from it and writes its outputs to ``viz.args``.
        """
        self.viz        = viz
        self.seed       = 0      # Value of the 'Seed' input; published as args.w0_seed.
        self.w_plus     = True   # State of the 'w+' checkbox.
        self.reg        = 0      # Published as args.reg; its input box is commented out below.
        self.lr         = 0.001  # Value of the 'Step Size' input.
        self.w_path     = ''     # Text of the latent-file path box.
        self.w_load     = None   # Object returned by torch.load, or None if nothing is loaded.
        self.defer_frames   = 0  # Set to 2 on load and counted down once per call.
        self.disabled_time  = 0  # While non-zero the controls are grayed out; counted down per call.

    @imgui_utils.scoped_by_object_id
    def __call__(self, show=True):
        """Draw the widget (if requested) and publish its state to ``viz.args``.

        Args:
            show: When true, the imgui controls are drawn and user input is
                processed. When false, nothing is drawn but the timers are
                still advanced and ``viz.args`` is still updated from the
                stored state.
        """
        viz = self.viz

        # Defined up front so that the assignments to viz.args at the end of
        # this method work when show is False (no button can be pressed on a
        # hidden widget, so "not pressed" is the correct value).
        reset_w = False

        if show:
            with imgui_utils.grayed_out(self.disabled_time != 0):
                imgui.text('Latent')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.font_size * 8.75):
                    changed, seed = imgui.input_int('Seed', self.seed)
                    if changed:
                        self.seed = seed
                        # A new seed discards any previously loaded latent code.
                        self.w_load = None

                # Path box for the latent code file.
                imgui.text(' ')
                imgui.same_line(viz.label_w)
                _changed, self.w_path = imgui_utils.input_text('##path', self.w_path, 1024,
                    flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE),
                    width=(-1),
                    help_text='Path to latent code')
                # Show the full path as a tooltip while hovering (but not editing).
                if imgui.is_item_hovered() and not imgui.is_item_active() and self.w_path != '':
                    imgui.set_tooltip(self.w_path)

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                if imgui_utils.button('Load latent', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)):
                    assert os.path.isfile(self.w_path), f"{self.w_path} does not exist!"
                    # NOTE(review): torch.load unpickles the file, so only
                    # trusted latent files should be loaded here.
                    self.w_load = torch.load(self.w_path)
                    self.defer_frames = 2
                    # Gray out the controls for a short while after loading.
                    self.disabled_time = 0.5

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                with imgui_utils.item_width(viz.button_w):
                    changed, lr = imgui.input_float('Step Size', self.lr)
                    if changed:
                        self.lr = lr

                # NOTE: the 'Regularize' input is disabled, so self.reg keeps
                # its initial value.
                # imgui.text(' ')
                # imgui.same_line(viz.label_w)
                # with imgui_utils.item_width(viz.button_w):
                #     changed, reg = imgui.input_float('Regularize', self.reg)
                #     if changed:
                #         self.reg = reg

                imgui.text(' ')
                imgui.same_line(viz.label_w)
                reset_w = imgui_utils.button('Reset', width=viz.button_w, enabled='image' in viz.result)
                imgui.same_line()
                # The 'w' box is drawn as the inverse of 'w+'; ticking it
                # clears w_plus before the 'w+' box is drawn.
                _clicked, w = imgui.checkbox('w', not self.w_plus)
                if w:
                    self.w_plus = False
                imgui.same_line()
                _clicked, self.w_plus = imgui.checkbox('w+', self.w_plus)

        # Advance the timers, clamping the disable timer at zero.
        self.disabled_time = max(self.disabled_time - viz.frame_delta, 0)
        if self.defer_frames > 0:
            self.defer_frames -= 1

        # Publish the current state from the stored attributes rather than
        # from locals that only exist when the widget was drawn.
        viz.args.w0_seed = self.seed
        viz.args.w_load = self.w_load
        viz.args.reg = self.reg
        viz.args.w_plus = self.w_plus
        viz.args.reset_w = reset_w
        viz.args.lr = self.lr

#----------------------------------------------------------------------------