import gc
import glob
import os

import imageio
import numpy as np
import torch
import torchvision
import wandb
from PIL import Image

from animation import clear_img_dir
from backend import ImagePromptEditor, log
from edit import blend_paths
from img_processing import custom_to_pil

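# Module-level frame counter; rendered frames are numbered sequentially across renders.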
num = 0


class PromptTransformHistory:
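    """Records the sequence of latent transforms produced over one prompt-optimization run."""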
    def __init__(self, iterations) -> None:
        self.iterations = iterations
        self.transforms = []


class ImageState:
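    """Editing state for the app: a blended base latent plus additive latent-space edit transforms."""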
    def __init__(self, vqgan, prompt_optimizer: ImagePromptEditor) -> None:
        self.vqgan = vqgan
        self.device = vqgan.device
        self.blend_latent = None
        self.quant = True
        self.path1 = None
        self.path2 = None
        self.img_dir = "./img_history"
        os.makedirs(self.img_dir, exist_ok=True)
        self.transform_history = []
        self.attn_mask = None
        self.prompt_optim = prompt_optimizer
        self._load_vectors()
        self.init_transforms()

    def _load_vectors(self):
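        """Load precomputed latent direction vectors used by the slider edits."""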
        self.lip_vector = torch.load(
            "./latent_vectors/lipvector.pt", map_location=self.device
        )
        self.blue_eyes_vector = torch.load(
            "./latent_vectors/2blue_eyes.pt", map_location=self.device
        )
        self.asian_vector = torch.load(
            "./latent_vectors/asian10.pt", map_location=self.device
        )

    def create_gif(self, total_duration, extend_frames, gif_name="face_edit.gif"):
        """Assemble the frames saved in self.img_dir into a GIF spanning `total_duration` seconds."""
        # Glob only .png frames so the durations list matches the image list
        paths = sorted(glob.glob(os.path.join(self.img_dir, "*.png")))
        if not paths:
            raise ValueError(f"No frames found in {self.img_dir}")
        frame_duration = total_duration / len(paths)
        durations = [frame_duration] * len(paths)
        if extend_frames:
            # Hold the first and last frames longer so the endpoints register
            durations[0] = 1.5
            durations[-1] = 3
        images = [imageio.imread(file_name) for file_name in paths]
        imageio.mimsave(gif_name, images, duration=durations)
        return gif_name

    def init_transforms(self):
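        """Reset all edit transforms to zero (identity edits)."""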
        self.blue_eyes = torch.zeros_like(self.lip_vector)
        self.lip_size = torch.zeros_like(self.lip_vector)
        self.asian_transform = torch.zeros_like(self.lip_vector)
        self.current_prompt_transforms = [torch.zeros_like(self.lip_vector)]

    def clear_transforms(self):
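        """Zero every transform, wipe the saved frame history, and re-render."""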
        self.init_transforms()
        clear_img_dir(self.img_dir)
        return self._render_all_transformations()

    def _latent_to_pil(self, latent):
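        """Decode a latent with the VQGAN and convert the first result to a PIL image."""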
        current_im = self.vqgan.decode(latent.to(self.device))[0]
        return custom_to_pil(current_im)

    def _get_mask(self, img, mask=None):
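        """Return a binary attention mask from an editor dict (e.g. a Gradio sketch), else fall back to `mask`."""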
        if img and "mask" in img and img["mask"] is not None:
            attn_mask = torchvision.transforms.ToTensor()(img["mask"])
            attn_mask = torch.ceil(attn_mask[0].to(self.device))
            print("mask set successfully")
        else:
            attn_mask = mask
        return attn_mask

    def set_mask(self, img):
        """Store the user-drawn attention mask and return it as a grayscale preview image."""
        self.attn_mask = self._get_mask(img)
        if self.attn_mask is None:
            return None
        # Map the mask from [-1, 1] to [0, 255] and convert to a PIL grayscale image
        x = self.attn_mask.clone().detach().cpu()
        x = torch.clamp(x, -1.0, 1.0)
        x = (x + 1.0) / 2.0
        x = (255 * x.numpy()).astype(np.uint8)
        return Image.fromarray(x, "L")

    @torch.no_grad()
    def _render_all_transformations(self, return_twice=True):
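        """Sum the blended latent with all active transforms, decode, and save a frame to the image history."""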
        global num
        current_vector_transforms = (
            self.blue_eyes,
            self.lip_size,
            self.asian_transform,
            sum(self.current_prompt_transforms),
        )
        new_latent = self.blend_latent + sum(current_vector_transforms)
        if self.quant:
            new_latent, _, _ = self.vqgan.quantize(new_latent.to(self.device))
        image = self._latent_to_pil(new_latent)
        image.save(f"{self.img_dir}/img_{num:06}.png")
        num += 1
        return (image, image) if return_twice else image

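    # Slider callbacks: each scales a precomputed latent direction, then re-renders.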
    def apply_rb_vector(self, weight):
        self.blue_eyes = weight * self.blue_eyes_vector
        return self._render_all_transformations()

    def apply_lip_vector(self, weight):
        self.lip_size = weight * self.lip_vector
        return self._render_all_transformations()

    def update_quant(self, val):
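        """Toggle re-quantizing the edited latent through the VQGAN codebook before decoding."""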
        self.quant = val
        return self._render_all_transformations()

    def apply_asian_vector(self, weight):
        self.asian_transform = weight * self.asian_vector
        return self._render_all_transformations()

    def update_images(self, path1, path2, blend_weight):
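        """Set the two source image paths, reset the frame history, and render the initial blend."""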
        if path1 is None and path2 is None:
            return None

        # Duplicate paths if one is empty
        if path1 is None:
            path1 = path2
        if path2 is None:
            path2 = path1

        self.path1, self.path2 = path1, path2
        if self.img_dir:
            clear_img_dir(self.img_dir)
        return self.blend(blend_weight)

    @torch.no_grad()
    def blend(self, weight):
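        """Interpolate between the two source images in latent space at `weight`, then re-render."""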
        _, latent = blend_paths(
            self.vqgan,
            self.path1,
            self.path2,
            weight=weight,
            show=False,
            device=self.device,
        )
        self.blend_latent = latent
        return self._render_all_transformations()

    @torch.no_grad()
    def rewind(self, index):
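        """Scrub to `index` percent of the way through the most recent prompt transform history."""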
        if not self.transform_history:
            print("No history")
            return self._render_all_transformations()
        prompt_transform = self.transform_history[-1]
        latent_index = int(index / 100 * (prompt_transform.iterations - 1))
        print("Rewinding to transform index", latent_index)
        self.current_prompt_transforms[-1] = prompt_transform.transforms[
            latent_index
        ].to(self.device)
        return self._render_all_transformations()

    def _init_logging(
        self, lr, iterations, lpips_weight, positive_prompts, negative_prompts
    ):
        wandb.init(reinit=True, project="face-editor")
        wandb.config.update({"Positive Prompts": positive_prompts})
        wandb.config.update({"Negative Prompts": negative_prompts})
        wandb.config.update(
            dict(lr=lr, iterations=iterations, lpips_weight=lpips_weight)
        )

    def apply_prompts(
        self,
        positive_prompts,
        negative_prompts,
        lr,
        iterations,
        lpips_weight,
        reconstruction_steps,
    ):
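        """Optimize the blended latent toward the '|'-separated prompts, yielding a rendered frame per step."""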
        if log:
            self._init_logging(
                lr, iterations, lpips_weight, positive_prompts, negative_prompts
            )
        transform_log = PromptTransformHistory(iterations + reconstruction_steps)
        transform_log.transforms.append(
            torch.zeros_like(self.blend_latent, requires_grad=False)
        )
        self.current_prompt_transforms.append(
            torch.zeros_like(self.blend_latent, requires_grad=False)
        )
        positive_prompts = [prompt.strip() for prompt in positive_prompts.split("|")]
        negative_prompts = [prompt.strip() for prompt in negative_prompts.split("|")]
        self.prompt_optim.set_params(
            lr,
            iterations,
            lpips_weight,
            attn_mask=self.attn_mask,
            reconstruction_steps=reconstruction_steps,
        )

        for i, transform in enumerate(
            self.prompt_optim.optimize(
                self.blend_latent, positive_prompts, negative_prompts
            )
        ):
            transform_log.transforms.append(transform.detach().cpu())
            self.current_prompt_transforms[-1] = transform
            with torch.no_grad():
                image = self._render_all_transformations(return_twice=False)
            if log:
                wandb.log({"image": wandb.Image(image)})
            yield (image, image)
        if log:
            wandb.finish()
        self.attn_mask = None
        self.transform_history.append(transform_log)
        gc.collect()
        torch.cuda.empty_cache()
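

# Example usage (illustrative sketch; `vqgan` and `prompt_optimizer` are assumed to be
# constructed by the app entry point -- the file paths and hyperparameters below are
# placeholders, not values defined in this module):
#
#     state = ImageState(vqgan, prompt_optimizer)
#     state.update_images("face1.png", "face2.png", blend_weight=0.5)
#     for preview, _ in state.apply_prompts(
#         "a smiling face", "", lr=0.1, iterations=20,
#         lpips_weight=0.5, reconstruction_steps=10,
#     ):
#         pass  # each yielded frame is also saved to state.img_dir
#     state.create_gif(total_duration=4, extend_frames=True)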