aijack committed on
Commit
a89e8c6
1 Parent(s): cbf51d9

Upload app.py

Files changed (1)
  1. app.py +249 -0
app.py ADDED
@@ -0,0 +1,249 @@
+ import os
+
+ import random
+ import torch
+ import gradio as gr
+
+ from e4e.models.psp import pSp
+ from util import *
+ from huggingface_hub import hf_hub_download
+
+ import tempfile
+ from argparse import Namespace
+ import shutil
+
+ import dlib
+ import numpy as np
+ import torchvision.transforms as transforms
+ from torchvision import utils
+
+ from model.sg2_model import Generator
+ from generate_videos import project_code_by_edit_name
+
+ import clip
+ import urllib.request
+
+ model_dir = "models"
+ os.makedirs(model_dir, exist_ok=True)
+
+ # Checkpoints hosted on the Hugging Face Hub: (repo_id, filename) pairs.
+ model_repos = {
+     "e4e": ("akhaliq/JoJoGAN_e4e_ffhq_encode", "e4e_ffhq_encode.pt"),
+     "dlib": ("akhaliq/jojogan_dlib", "shape_predictor_68_face_landmarks.dat"),
+     "base": ("akhaliq/jojogan-stylegan2-ffhq-config-f", "stylegan2-ffhq-config-f.pt"),
+     "sketch": ("rinong/stylegan-nada-models", "sketch.pt"),
+     "santa": ("mjdolan/stylegan-nada-models", "santa.pt"),
+     "jesus": ("mjdolan/stylegan-nada-models", "jesus.pt"),
+     "mariah": ("mjdolan/stylegan-nada-models", "mariah.pt"),
+     "heat_miser": ("mjdolan/stylegan-nada-models", "heat.pt"),
+     "claymation": ("mjdolan/stylegan-nada-models", "claymation.pt"),
+     "elf": ("mjdolan/stylegan-nada-models", "elf.pt"),
+     "krampus": ("mjdolan/stylegan-nada-models", "krampus.pt"),
+     "grinch": ("mjdolan/stylegan-nada-models", "grinch.pt"),
+     "jack_frost": ("mjdolan/stylegan-nada-models", "jack_frost.pt"),
+     "rudolph": ("mjdolan/stylegan-nada-models", "rudolph.pt"),
+     "home_alone": ("mjdolan/stylegan-nada-models", "home_alone.pt"),
+     "puppet": ("rinong/stylegan-nada-models", "plastic_puppet.pt"),
+     "crochet": ("rinong/stylegan-nada-models", "crochet.pt"),
+     "shrek": ("rinong/stylegan-nada-models", "shrek.pt"),
+     "pixar": ("rinong/stylegan-nada-models", "pixar.pt"),
+ }
+
+ # Maps each UI modifier to an (edit_name, strength) pair for latent-space editing.
+ interface_gan_map = {
+     "None": None,
+     "Masculine": ("gender", 1.0),
+     "Feminine": ("gender", -1.0),
+     "Smiling": ("smile", 1.0),
+     "Frowning": ("smile", -1.0),
+     "Young": ("age", -1.0),
+     "Old": ("age", 1.0),
+     "Long Hair": ("hair_length", -1.0),
+     "Short Hair": ("hair_length", 1.0),
+ }
+
+
+ def get_models():
+     os.makedirs(model_dir, exist_ok=True)
+
+     model_paths = {}
+
+     # Download every checkpoint (cached by huggingface_hub on repeat runs).
+     for model_name, (repo_id, filename) in model_repos.items():
+         model_paths[model_name] = hf_hub_download(repo_id=repo_id, filename=filename)
+
+     return model_paths
+
+
+ model_paths = get_models()
+
+
+ class ImageEditor(object):
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # StyleGAN2 FFHQ architecture hyperparameters.
+         latent_size = 512
+         n_mlp = 8
+         channel_mult = 2
+         model_size = 1024
+
+         self.generators = {}
+
+         # Every checkpoint except the helpers (e4e encoder, dlib predictor) is a style generator.
+         self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib"]]
+
+         for model in self.model_list:
+             g_ema = Generator(
+                 model_size, latent_size, n_mlp, channel_multiplier=channel_mult
+             ).to(self.device)
+
+             checkpoint = torch.load(model_paths[model], map_location=self.device)
+             g_ema.load_state_dict(checkpoint["g_ema"])
+
+             self.generators[model] = g_ema
+
+         # e4e encoder used to invert input faces into latent codes.
+         self.experiment_args = {"model_path": model_paths["e4e"]}
+         self.experiment_args["transform"] = transforms.Compose(
+             [
+                 transforms.Resize((256, 256)),
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+             ]
+         )
+         self.resize_dims = (256, 256)
+
+         model_path = self.experiment_args["model_path"]
+
+         ckpt = torch.load(model_path, map_location=self.device)
+         opts = ckpt["opts"]
+         opts["checkpoint_path"] = model_path
+         opts = Namespace(**opts)
+
+         self.e4e_net = pSp(opts, self.device)
+         self.e4e_net.eval()
+
+         # dlib landmark predictor for face alignment.
+         self.shape_predictor = dlib.shape_predictor(model_paths["dlib"])
+
+         self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
+
+         print("setup complete")
+
+     def get_style_list(self):
+         return list(self.generators.keys())
+
+     def invert_image(self, input_image):
+         # Align/crop the face, then project it into the e4e latent space.
+         input_image = self.run_alignment(str(input_image))
+         input_image = input_image.resize(self.resize_dims)
+
+         img_transforms = self.experiment_args["transform"]
+         transformed_image = img_transforms(input_image)
+
+         with torch.no_grad():
+             images, latents = self.run_on_batch(transformed_image.unsqueeze(0))
+             result_image, latent = images[0], latents[0]
+
+         inverted_latent = latent.unsqueeze(0).unsqueeze(1)
+
+         return inverted_latent
+
+     def get_generators_for_styles(self, output_styles, loop_styles=False):
+         # Always start with the base model if it was chosen.
+         if "base" in output_styles:
+             output_styles.insert(0, output_styles.pop(output_styles.index("base")))
+         if loop_styles:
+             output_styles.append(output_styles[0])
+
+         return [self.generators[style] for style in output_styles]
+
+     def get_target_latent(self, source_latent, alter, generators):
+         if alter == "None":
+             # No modifier selected: use the inverted latent as-is.
+             return source_latent.squeeze(0)
+
+         edit_name, strength = interface_gan_map[alter]
+         np_source_latent = source_latent.squeeze(0).cpu().detach().numpy()
+         projected_code_np = project_code_by_edit_name(np_source_latent, edit_name, strength)
+         return torch.from_numpy(projected_code_np).float().to(self.device)
+
+     def edit_image(self, input, output_styles, edit_choices):
+         return self.predict(input, output_styles, edit_choices=edit_choices)
+
+     def predict(
+         self,
+         input,              # Input image path
+         output_styles,      # Style checkbox options.
+         loop_styles=False,  # Loop back to the initial style
+         edit_choices=None,  # Optional modifier name (a key of interface_gan_map)
+     ):
+         if edit_choices is None:
+             edit_choices = "None"
+
+         out_dir = tempfile.mkdtemp()
+
+         inverted_latent = self.invert_image(input)
+         generators = self.get_generators_for_styles(output_styles, loop_styles)
+         output_paths = []
+
+         with torch.no_grad():
+             for g_ema in generators:
+                 latent_for_gen = self.get_target_latent(inverted_latent, edit_choices, generators)
+
+                 img, _ = g_ema([latent_for_gen], input_is_latent=True, truncation=1, randomize_noise=False)
+
+                 output_path = os.path.join(out_dir, f"out_{len(output_paths)}.jpg")
+                 # Map generator output from [-1, 1] to [0, 1] when saving.
+                 # (Newer torchvision releases call this argument `value_range`.)
+                 utils.save_image(img, output_path, nrow=1, normalize=True, range=(-1, 1))
+
+                 output_paths.append(output_path)
+
+         return output_paths
+
+     def run_alignment(self, image_path):
+         # align_face is provided by util (star-imported above).
+         aligned_image = align_face(filepath=image_path, predictor=self.shape_predictor)
+         print("Aligned image has shape: {}".format(aligned_image.size))
+         return aligned_image
+
+     def run_on_batch(self, inputs):
+         images, latents = self.e4e_net(
+             inputs.to(self.device).float(), randomize_noise=False, return_latents=True
+         )
+         return images, latents
+
+
+ editor = ImageEditor()
+
+ # Fetch an example image for the demo.
+ img_url = "http://claireye.com.tw/img/230212a.jpg"
+ urllib.request.urlretrieve(img_url, "pose.jpg")
+
+ blocks = gr.Blocks(theme="darkdefault")
+
+ with blocks:
+     gr.Markdown("<h1><center>Holiday Filters (StyleGAN-NADA)</center></h1>")
+     gr.Markdown(
+         "<div>Upload an image of your face, pick your desired output styles, pick any modifiers, and apply StyleGAN-based editing.</div>"
+     )
+     with gr.Row():
+         with gr.Column():
+             input_img = gr.Image(type="filepath", label="Input image")
+         with gr.Column():
+             style_choice = gr.CheckboxGroup(
+                 choices=editor.get_style_list(),
+                 value=editor.get_style_list(),
+                 type="value",
+                 label="Styles",
+             )
+             alter = gr.Dropdown(
+                 choices=["None", "Masculine", "Feminine", "Smiling", "Frowning",
+                          "Young", "Old", "Short Hair", "Long Hair"],
+                 value="None",
+                 label="Additional Modifiers",
+             )
+             img_button = gr.Button("Edit Image")
+
+     with gr.Row():
+         img_output = gr.Gallery(label="Output Images")
+         img_output.style(grid=(3, 3, 4, 4, 6, 6))
+
+     img_button.click(fn=editor.edit_image, inputs=[input_img, style_choice, alter], outputs=img_output)
+     ex = gr.Examples(
+         examples=[["pose.jpg", editor.get_style_list(), "Smiling"],
+                   ["pose.jpg", editor.get_style_list(), "Long Hair"]],
+         fn=editor.edit_image,
+         inputs=[input_img, style_choice, alter],
+         outputs=[img_output],
+         cache_examples=True,
+         run_on_click=True,
+     )
+     ex.dataset.headers = [""]
+     article = "<p style='text-align: center'><a href='http://claireye.com.tw'>Claireye</a> | 2023</p>"
+     gr.Markdown(article)
+
+ blocks.launch(enable_queue=True)