hysts HF staff commited on
Commit
73099b9
1 Parent(s): 6ee2510

Update to the latest blocks version

Browse files
Files changed (5) hide show
  1. .pre-commit-config.yaml +35 -0
  2. .style.yapf +5 -0
  3. app.py +197 -262
  4. dualstylegan.py +166 -0
  5. style.css +17 -0
.pre-commit-config.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v4.2.0
4
+ hooks:
5
+ - id: check-executables-have-shebangs
6
+ - id: check-json
7
+ - id: check-merge-conflict
8
+ - id: check-shebang-scripts-are-executable
9
+ - id: check-toml
10
+ - id: check-yaml
11
+ - id: double-quote-string-fixer
12
+ - id: end-of-file-fixer
13
+ - id: mixed-line-ending
14
+ args: ['--fix=lf']
15
+ - id: requirements-txt-fixer
16
+ - id: trailing-whitespace
17
+ - repo: https://github.com/myint/docformatter
18
+ rev: v1.4
19
+ hooks:
20
+ - id: docformatter
21
+ args: ['--in-place']
22
+ - repo: https://github.com/pycqa/isort
23
+ rev: 5.10.1
24
+ hooks:
25
+ - id: isort
26
+ - repo: https://github.com/pre-commit/mirrors-mypy
27
+ rev: v0.812
28
+ hooks:
29
+ - id: mypy
30
+ args: ['--ignore-missing-imports']
31
+ - repo: https://github.com/google/yapf
32
+ rev: v0.32.0
33
+ hooks:
34
+ - id: yapf
35
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
app.py CHANGED
@@ -3,294 +3,229 @@
3
  from __future__ import annotations
4
 
5
  import argparse
6
- import functools
7
- import os
8
  import pathlib
9
- import sys
10
- from typing import Callable
11
 
12
- import dlib
13
  import gradio as gr
14
- import huggingface_hub
15
- import numpy as np
16
- import PIL.Image
17
- import torch
18
- import torch.nn as nn
19
- import torchvision.transforms as T
20
 
21
- if os.environ.get('SYSTEM') == 'spaces':
22
- os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py")
23
- os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py")
24
 
25
- sys.path.insert(0, 'DualStyleGAN')
26
 
27
- from model.dualstylegan import DualStyleGAN
28
- from model.encoder.align_all_parallel import align_face
29
- from model.encoder.psp import pSp
30
-
31
- TITLE = 'williamyang1991/DualStyleGAN'
32
- DESCRIPTION = '''This is an unofficial demo for https://github.com/williamyang1991/DualStyleGAN.
33
-
34
- ![overview](https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/overview.jpg)
35
 
36
- You can select style images for each style type from the tables below.
37
- The style image index should be in the following range:
38
- (cartoon: 0-316, caricature: 0-198, anime: 0-173, arcane: 0-99, comic: 0-100, pixar: 0-121, slamdunk: 0-119)
39
 
40
- Expected execution time on Hugging Face Spaces: 15s
41
- '''
42
- ARTICLE = '''## Style images
 
 
 
 
 
 
 
43
 
44
- Note that the style images here for Arcane, comic, Pixar, and Slamdunk are the reconstructed ones, not the original ones due to copyright issues.
45
 
46
- ### Cartoon
47
- ![cartoon style images](https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/cartoon_overview.jpg)
 
 
 
 
 
 
 
 
 
 
48
 
49
- ### Caricature
50
- ![caricature style images](https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/caricature_overview.jpg)
51
 
52
- ### Anime
53
- ![anime style images](https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/anime_overview.jpg)
 
54
 
55
- ### Arcane
56
- ![arcane style images](https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_arcane_overview.jpg)
57
 
58
- ### Comic
59
- ![comic style images](https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_comic_overview.jpg)
 
 
 
 
 
 
 
 
 
60
 
61
- ### Pixar
62
- ![pixar style images](https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_pixar_overview.jpg)
63
 
64
- ### Slamdunk
65
- ![slamdunk style images](https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/Reconstruction_slamdunk_overview.jpg)
 
66
 
67
- <center><img src="https://visitor-badge.glitch.me/badge?page_id=hysts.dualstylegan" alt="visitor badge"/></center>
68
- '''
69
 
70
- TOKEN = os.environ['TOKEN']
71
- MODEL_REPO = 'hysts/DualStyleGAN'
72
 
73
 
74
- def parse_args() -> argparse.Namespace:
75
- parser = argparse.ArgumentParser()
76
- parser.add_argument('--device', type=str, default='cpu')
77
- parser.add_argument('--theme', type=str)
78
- parser.add_argument('--live', action='store_true')
79
- parser.add_argument('--share', action='store_true')
80
- parser.add_argument('--port', type=int)
81
- parser.add_argument('--disable-queue',
82
- dest='enable_queue',
83
- action='store_false')
84
- parser.add_argument('--allow-flagging', type=str, default='never')
85
- return parser.parse_args()
86
 
87
 
88
- def load_encoder(device: torch.device) -> nn.Module:
89
- ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
90
- 'models/encoder.pt',
91
- use_auth_token=TOKEN)
92
- ckpt = torch.load(ckpt_path, map_location='cpu')
93
- opts = ckpt['opts']
94
- opts['device'] = device.type
95
- opts['checkpoint_path'] = ckpt_path
96
- opts = argparse.Namespace(**opts)
97
- model = pSp(opts)
98
- model.to(device)
99
- model.eval()
100
- return model
101
-
102
-
103
- def load_generator(style_type: str, device: torch.device) -> nn.Module:
104
- model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
105
- ckpt_path = huggingface_hub.hf_hub_download(
106
- MODEL_REPO, f'models/{style_type}/generator.pt', use_auth_token=TOKEN)
107
- ckpt = torch.load(ckpt_path, map_location='cpu')
108
- model.load_state_dict(ckpt['g_ema'])
109
- model.to(device)
110
- model.eval()
111
- return model
112
-
113
-
114
- def load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
115
- if style_type in ['cartoon', 'caricature', 'anime']:
116
- filename = 'refined_exstyle_code.npy'
117
- else:
118
- filename = 'exstyle_code.npy'
119
- path = huggingface_hub.hf_hub_download(MODEL_REPO,
120
- f'models/{style_type}/{filename}',
121
- use_auth_token=TOKEN)
122
- exstyles = np.load(path, allow_pickle=True).item()
123
- return exstyles
124
-
125
-
126
- def create_transform() -> Callable:
127
- transform = T.Compose([
128
- T.Resize(256),
129
- T.CenterCrop(256),
130
- T.ToTensor(),
131
- T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
132
- ])
133
- return transform
134
-
135
-
136
- def create_dlib_landmark_model():
137
- path = huggingface_hub.hf_hub_download(
138
- 'hysts/dlib_face_landmark_model',
139
- 'shape_predictor_68_face_landmarks.dat',
140
- use_auth_token=TOKEN)
141
- return dlib.shape_predictor(path)
142
-
143
-
144
- def denormalize(tensor: torch.Tensor) -> torch.Tensor:
145
- return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)
146
-
147
-
148
- def postprocess(tensor: torch.Tensor) -> PIL.Image.Image:
149
- tensor = denormalize(tensor)
150
- image = tensor.cpu().numpy().transpose(1, 2, 0)
151
- return PIL.Image.fromarray(image)
152
-
153
-
154
- @torch.inference_mode()
155
- def run(
156
- image,
157
- style_type: str,
158
- style_id: float,
159
- structure_weight: float,
160
- color_weight: float,
161
- dlib_landmark_model,
162
- encoder: nn.Module,
163
- generator_dict: dict[str, nn.Module],
164
- exstyle_dict: dict[str, dict[str, np.ndarray]],
165
- transform: Callable,
166
- device: torch.device,
167
- ) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image,
168
- PIL.Image.Image]:
169
- generator = generator_dict[style_type]
170
- exstyles = exstyle_dict[style_type]
171
-
172
- style_id = int(style_id)
173
- style_id = min(max(0, style_id), len(exstyles) - 1)
174
-
175
- stylename = list(exstyles.keys())[style_id]
176
-
177
- image = align_face(filepath=image.name, predictor=dlib_landmark_model)
178
- input_data = transform(image).unsqueeze(0).to(device)
179
-
180
- img_rec, instyle = encoder(input_data,
181
- randomize_noise=False,
182
- return_latents=True,
183
- z_plus_latent=True,
184
- return_z_plus_latent=True,
185
- resize=False)
186
- img_rec = torch.clamp(img_rec.detach(), -1, 1)
187
-
188
- latent = torch.tensor(exstyles[stylename]).repeat(2, 1, 1).to(device)
189
- # latent[0] for both color and structrue transfer and latent[1] for only structrue transfer
190
- latent[1, 7:18] = instyle[0, 7:18]
191
- exstyle = generator.generator.style(
192
- latent.reshape(latent.shape[0] * latent.shape[1],
193
- latent.shape[2])).reshape(latent.shape)
194
-
195
- img_gen, _ = generator([instyle.repeat(2, 1, 1)],
196
- exstyle,
197
- z_plus_latent=True,
198
- truncation=0.7,
199
- truncation_latent=0,
200
- use_res=True,
201
- interp_weights=[structure_weight] * 7 +
202
- [color_weight] * 11)
203
- img_gen = torch.clamp(img_gen.detach(), -1, 1)
204
- # deactivate color-related layers by setting w_c = 0
205
- img_gen2, _ = generator([instyle],
206
- exstyle[0:1],
207
- z_plus_latent=True,
208
- truncation=0.7,
209
- truncation_latent=0,
210
- use_res=True,
211
- interp_weights=[structure_weight] * 7 + [0] * 11)
212
- img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)
213
-
214
- img_rec = postprocess(img_rec[0])
215
- img_gen0 = postprocess(img_gen[0])
216
- img_gen1 = postprocess(img_gen[1])
217
- img_gen2 = postprocess(img_gen2[0])
218
-
219
- return image, img_rec, img_gen0, img_gen1, img_gen2
220
 
221
 
222
  def main():
223
  args = parse_args()
224
- device = torch.device(args.device)
225
-
226
- style_types = [
227
- 'cartoon',
228
- 'caricature',
229
- 'anime',
230
- 'arcane',
231
- 'comic',
232
- 'pixar',
233
- 'slamdunk',
234
- ]
235
- generator_dict = {
236
- style_type: load_generator(style_type, device)
237
- for style_type in style_types
238
- }
239
- exstyle_dict = {
240
- style_type: load_exstylecode(style_type)
241
- for style_type in style_types
242
- }
243
-
244
- dlib_landmark_model = create_dlib_landmark_model()
245
- encoder = load_encoder(device)
246
- transform = create_transform()
247
-
248
- func = functools.partial(run,
249
- dlib_landmark_model=dlib_landmark_model,
250
- encoder=encoder,
251
- generator_dict=generator_dict,
252
- exstyle_dict=exstyle_dict,
253
- transform=transform,
254
- device=device)
255
- func = functools.update_wrapper(func, run)
256
-
257
- image_paths = sorted(pathlib.Path('images').glob('*.jpg'))
258
- examples = [[path.as_posix(), 'cartoon', 26, 0.6, 1.0]
259
- for path in image_paths]
260
-
261
- gr.Interface(
262
- func,
263
- [
264
- gr.inputs.Image(type='file', label='Input Image'),
265
- gr.inputs.Radio(style_types,
266
- type='value',
267
- default='cartoon',
268
- label='Style Type'),
269
- gr.inputs.Number(default=26, label='Style Image Index'),
270
- gr.inputs.Slider(
271
- 0, 1, step=0.1, default=0.6, label='Structure Weight'),
272
- gr.inputs.Slider(0, 1, step=0.1, default=1.0,
273
- label='Color Weight'),
274
- ],
275
- [
276
- gr.outputs.Image(type='pil', label='Aligned Face'),
277
- gr.outputs.Image(type='pil', label='Reconstructed'),
278
- gr.outputs.Image(type='pil',
279
- label='Result 1 (Color and structure transfer)'),
280
- gr.outputs.Image(type='pil',
281
- label='Result 2 (Structure transfer only)'),
282
- gr.outputs.Image(
283
- type='pil',
284
- label='Result 3 (Color-related layers deactivated)'),
285
- ],
286
- examples=examples,
287
- title=TITLE,
288
- description=DESCRIPTION,
289
- article=ARTICLE,
290
- theme=args.theme,
291
- allow_flagging=args.allow_flagging,
292
- live=args.live,
293
- ).launch(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  enable_queue=args.enable_queue,
295
  server_port=args.port,
296
  share=args.share,
3
  from __future__ import annotations
4
 
5
  import argparse
 
 
6
  import pathlib
 
 
7
 
 
8
  import gradio as gr
 
 
 
 
 
 
9
 
10
+ from dualstylegan import Model
 
 
11
 
12
+ DESCRIPTION = '''# Portrait Style Transfer with <a href="https://github.com/williamyang1991/DualStyleGAN">DualStyleGAN</a>
13
 
14
+ <img id="overview" alt="overview" src="https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images/overview.jpg" />
15
+ '''
16
+ FOOTER = '<img id="visitor-badge" alt="visitor badge" src="https://visitor-badge.glitch.me/badge?page_id=gradio-blocks.dualstylegan" />'
 
 
 
 
 
17
 
 
 
 
18
 
19
+ def parse_args() -> argparse.Namespace:
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('--device', type=str, default='cpu')
22
+ parser.add_argument('--theme', type=str)
23
+ parser.add_argument('--share', action='store_true')
24
+ parser.add_argument('--port', type=int)
25
+ parser.add_argument('--disable-queue',
26
+ dest='enable_queue',
27
+ action='store_false')
28
+ return parser.parse_args()
29
 
 
30
 
31
+ def get_style_image_url(style_name: str) -> str:
32
+ base_url = 'https://raw.githubusercontent.com/williamyang1991/DualStyleGAN/main/doc_images'
33
+ filenames = {
34
+ 'cartoon': 'cartoon_overview.jpg',
35
+ 'caricature': 'caricature_overview.jpg',
36
+ 'anime': 'anime_overview.jpg',
37
+ 'arcane': 'Reconstruction_arcane_overview.jpg',
38
+ 'comic': 'Reconstruction_comic_overview.jpg',
39
+ 'pixar': 'Reconstruction_pixar_overview.jpg',
40
+ 'slamdunk': 'Reconstruction_slamdunk_overview.jpg',
41
+ }
42
+ return f'{base_url}/{filenames[style_name]}'
43
 
 
 
44
 
45
+ def get_style_image_markdown_text(style_name: str) -> str:
46
+ url = get_style_image_url(style_name)
47
+ return f'<center><img id="style-image" src="{url}" alt="style image"></center>'
48
 
 
 
49
 
50
+ def update_slider(choice: str) -> dict:
51
+ max_vals = {
52
+ 'cartoon': 316,
53
+ 'caricature': 198,
54
+ 'anime': 173,
55
+ 'arcane': 99,
56
+ 'comic': 100,
57
+ 'pixar': 121,
58
+ 'slamdunk': 119,
59
+ }
60
+ return gr.Slider.update(maximum=max_vals[choice])
61
 
 
 
62
 
63
+ def update_style_image(style_name: str) -> dict:
64
+ text = get_style_image_markdown_text(style_name)
65
+ return gr.Markdown.update(value=text)
66
 
 
 
67
 
68
+ def set_example_image(example: list) -> dict:
69
+ return gr.Image.update(value=example[0])
70
 
71
 
72
+ def set_example_styles(example: list) -> list[dict]:
73
+ return [
74
+ gr.Radio.update(value=example[0]),
75
+ gr.Slider.update(value=example[1]),
76
+ ]
 
 
 
 
 
 
 
77
 
78
 
79
+ def set_example_weights(example: list) -> list[dict]:
80
+ return [
81
+ gr.Slider.update(value=example[0]),
82
+ gr.Slider.update(value=example[1]),
83
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
 
86
  def main():
87
  args = parse_args()
88
+ model = Model(device=args.device)
89
+
90
+ with gr.Blocks(theme=args.theme, css='style.css') as demo:
91
+ gr.Markdown(DESCRIPTION)
92
+
93
+ with gr.Box():
94
+ gr.Markdown('''## Step 1 (Preprocess Input Image)
95
+
96
+ - Drop an image containing a near-frontal face to the **Input Image**.
97
+ - If there are multiple faces in the image, hit the Edit button in the upper right corner and crop the input image beforehand.
98
+ - Hit the **Detect & Align Face** button.
99
+ - Hit the **Reconstruct Face** button.
100
+ - The final result will be based on this **Reconstructed Face**. So, if the reconstructed image is not satisfactory, you may want to change the input image.
101
+ ''')
102
+ with gr.Row():
103
+ with gr.Column():
104
+ with gr.Row():
105
+ input_image = gr.Image(label='Input Image',
106
+ type='file')
107
+ with gr.Row():
108
+ detect_button = gr.Button('Detect & Align Face')
109
+ with gr.Column():
110
+ with gr.Row():
111
+ aligned_face = gr.Image(label='Aligned Face',
112
+ type='numpy',
113
+ interactive=False)
114
+ with gr.Row():
115
+ reconstruct_button = gr.Button('Reconstruct Face')
116
+ with gr.Column():
117
+ reconstructed_face = gr.Image(label='Reconstructed Face',
118
+ type='numpy')
119
+ instyle = gr.Variable()
120
+
121
+ with gr.Row():
122
+ paths = sorted(pathlib.Path('images').glob('*.jpg'))
123
+ example_images = gr.Dataset(components=[input_image],
124
+ samples=[[path.as_posix()]
125
+ for path in paths])
126
+
127
+ with gr.Box():
128
+ gr.Markdown('''## Step 2 (Select Style Image)
129
+
130
+ - Select **Style Type**.
131
+ - Select **Style Image Index** from the image table below.
132
+ ''')
133
+ with gr.Row():
134
+ with gr.Column():
135
+ style_type = gr.Radio(model.style_types,
136
+ label='Style Type')
137
+ text = get_style_image_markdown_text('cartoon')
138
+ style_image = gr.Markdown(value=text)
139
+ style_index = gr.Slider(0,
140
+ 316,
141
+ value=26,
142
+ step=1,
143
+ label='Style Image Index')
144
+
145
+ with gr.Row():
146
+ example_styles = gr.Dataset(
147
+ components=[style_type, style_index],
148
+ samples=[
149
+ ['cartoon', 26],
150
+ ['caricature', 65],
151
+ ['arcane', 63],
152
+ ['pixar', 80],
153
+ ])
154
+
155
+ with gr.Box():
156
+ gr.Markdown('''## Step 3 (Generate Style Transferred Image)
157
+
158
+ - Adjust **Structure Weight** and **Color Weight**.
159
+ - These are weights for the style image, so the larger the value, the closer the resulting image will be to the style image.
160
+ - Hit the **Generate** button.
161
+ ''')
162
+ with gr.Row():
163
+ with gr.Column():
164
+ with gr.Row():
165
+ structure_weight = gr.Slider(0,
166
+ 1,
167
+ value=0.6,
168
+ step=0.1,
169
+ label='Structure Weight')
170
+ with gr.Row():
171
+ color_weight = gr.Slider(0,
172
+ 1,
173
+ value=1,
174
+ step=0.1,
175
+ label='Color Weight')
176
+ with gr.Row():
177
+ structure_only = gr.Checkbox(label='Structure Only')
178
+ with gr.Row():
179
+ generate_button = gr.Button('Generate')
180
+
181
+ with gr.Column():
182
+ result = gr.Image(label='Result')
183
+
184
+ with gr.Row():
185
+ example_weights = gr.Dataset(
186
+ components=[structure_weight, color_weight],
187
+ samples=[
188
+ [0.6, 1.0],
189
+ [0.3, 1.0],
190
+ [0.0, 1.0],
191
+ [1.0, 0.0],
192
+ ])
193
+
194
+ gr.Markdown(FOOTER)
195
+
196
+ detect_button.click(fn=model.detect_and_align_face,
197
+ inputs=input_image,
198
+ outputs=aligned_face)
199
+ reconstruct_button.click(fn=model.reconstruct_face,
200
+ inputs=aligned_face,
201
+ outputs=[reconstructed_face, instyle])
202
+ style_type.change(fn=update_slider,
203
+ inputs=style_type,
204
+ outputs=style_index)
205
+ style_type.change(fn=update_style_image,
206
+ inputs=style_type,
207
+ outputs=style_image)
208
+ generate_button.click(fn=model.generate,
209
+ inputs=[
210
+ style_type,
211
+ style_index,
212
+ structure_weight,
213
+ color_weight,
214
+ structure_only,
215
+ instyle,
216
+ ],
217
+ outputs=result)
218
+ example_images.click(fn=set_example_image,
219
+ inputs=example_images,
220
+ outputs=example_images.components)
221
+ example_styles.click(fn=set_example_styles,
222
+ inputs=example_styles,
223
+ outputs=example_styles.components)
224
+ example_weights.click(fn=set_example_weights,
225
+ inputs=example_weights,
226
+ outputs=example_weights.components)
227
+
228
+ demo.launch(
229
  enable_queue=args.enable_queue,
230
  server_port=args.port,
231
  share=args.share,
dualstylegan.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import os
5
+ import sys
6
+ from typing import Callable, Union
7
+
8
+ import dlib
9
+ import huggingface_hub
10
+ import numpy as np
11
+ import PIL.Image
12
+ import torch
13
+ import torch.nn as nn
14
+ import torchvision.transforms as T
15
+
16
+ if os.environ.get('SYSTEM') == 'spaces':
17
+ os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py")
18
+ os.system("sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py")
19
+
20
+ sys.path.insert(0, 'DualStyleGAN')
21
+
22
+ from model.dualstylegan import DualStyleGAN
23
+ from model.encoder.align_all_parallel import align_face
24
+ from model.encoder.psp import pSp
25
+
26
+ HF_TOKEN = os.environ['HF_TOKEN']
27
+ MODEL_REPO = 'hysts/DualStyleGAN'
28
+
29
+
30
+ class Model:
31
+ def __init__(self, device: Union[torch.device, str]):
32
+ self.device = torch.device(device)
33
+ self.landmark_model = self._create_dlib_landmark_model()
34
+ self.encoder = self._load_encoder()
35
+ self.transform = self._create_transform()
36
+
37
+ self.style_types = [
38
+ 'cartoon',
39
+ 'caricature',
40
+ 'anime',
41
+ 'arcane',
42
+ 'comic',
43
+ 'pixar',
44
+ 'slamdunk',
45
+ ]
46
+ self.generator_dict = {
47
+ style_type: self._load_generator(style_type)
48
+ for style_type in self.style_types
49
+ }
50
+ self.exstyle_dict = {
51
+ style_type: self._load_exstylecode(style_type)
52
+ for style_type in self.style_types
53
+ }
54
+
55
+ @staticmethod
56
+ def _create_dlib_landmark_model():
57
+ path = huggingface_hub.hf_hub_download(
58
+ 'hysts/dlib_face_landmark_model',
59
+ 'shape_predictor_68_face_landmarks.dat',
60
+ use_auth_token=HF_TOKEN)
61
+ return dlib.shape_predictor(path)
62
+
63
+ def _load_encoder(self) -> nn.Module:
64
+ ckpt_path = huggingface_hub.hf_hub_download(MODEL_REPO,
65
+ 'models/encoder.pt',
66
+ use_auth_token=HF_TOKEN)
67
+ ckpt = torch.load(ckpt_path, map_location='cpu')
68
+ opts = ckpt['opts']
69
+ opts['device'] = self.device.type
70
+ opts['checkpoint_path'] = ckpt_path
71
+ opts = argparse.Namespace(**opts)
72
+ model = pSp(opts)
73
+ model.to(self.device)
74
+ model.eval()
75
+ return model
76
+
77
+ @staticmethod
78
+ def _create_transform() -> Callable:
79
+ transform = T.Compose([
80
+ T.Resize(256),
81
+ T.CenterCrop(256),
82
+ T.ToTensor(),
83
+ T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
84
+ ])
85
+ return transform
86
+
87
+ def _load_generator(self, style_type: str) -> nn.Module:
88
+ model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
89
+ ckpt_path = huggingface_hub.hf_hub_download(
90
+ MODEL_REPO,
91
+ f'models/{style_type}/generator.pt',
92
+ use_auth_token=HF_TOKEN)
93
+ ckpt = torch.load(ckpt_path, map_location='cpu')
94
+ model.load_state_dict(ckpt['g_ema'])
95
+ model.to(self.device)
96
+ model.eval()
97
+ return model
98
+
99
+ @staticmethod
100
+ def _load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
101
+ if style_type in ['cartoon', 'caricature', 'anime']:
102
+ filename = 'refined_exstyle_code.npy'
103
+ else:
104
+ filename = 'exstyle_code.npy'
105
+ path = huggingface_hub.hf_hub_download(
106
+ MODEL_REPO,
107
+ f'models/{style_type}/{filename}',
108
+ use_auth_token=HF_TOKEN)
109
+ exstyles = np.load(path, allow_pickle=True).item()
110
+ return exstyles
111
+
112
+ def detect_and_align_face(self, image) -> np.ndarray:
113
+ image = align_face(filepath=image.name, predictor=self.landmark_model)
114
+ return image
115
+
116
+ @staticmethod
117
+ def denormalize(tensor: torch.Tensor) -> torch.Tensor:
118
+ return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)
119
+
120
+ def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
121
+ tensor = self.denormalize(tensor)
122
+ return tensor.cpu().numpy().transpose(1, 2, 0)
123
+
124
+ @torch.inference_mode()
125
+ def reconstruct_face(self,
126
+ image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
127
+ image = PIL.Image.fromarray(image)
128
+ input_data = self.transform(image).unsqueeze(0).to(self.device)
129
+ img_rec, instyle = self.encoder(input_data,
130
+ randomize_noise=False,
131
+ return_latents=True,
132
+ z_plus_latent=True,
133
+ return_z_plus_latent=True,
134
+ resize=False)
135
+ img_rec = torch.clamp(img_rec.detach(), -1, 1)
136
+ img_rec = self.postprocess(img_rec[0])
137
+ return img_rec, instyle
138
+
139
+ @torch.inference_mode()
140
+ def generate(self, style_type: str, style_id: int, structure_weight: float,
141
+ color_weight: float, structure_only: bool,
142
+ instyle: torch.Tensor) -> np.ndarray:
143
+ generator = self.generator_dict[style_type]
144
+ exstyles = self.exstyle_dict[style_type]
145
+
146
+ style_id = int(style_id)
147
+ stylename = list(exstyles.keys())[style_id]
148
+
149
+ latent = torch.tensor(exstyles[stylename]).to(self.device)
150
+ if structure_only:
151
+ latent[0, 7:18] = instyle[0, 7:18]
152
+ exstyle = generator.generator.style(
153
+ latent.reshape(latent.shape[0] * latent.shape[1],
154
+ latent.shape[2])).reshape(latent.shape)
155
+
156
+ img_gen, _ = generator([instyle],
157
+ exstyle,
158
+ z_plus_latent=True,
159
+ truncation=0.7,
160
+ truncation_latent=0,
161
+ use_res=True,
162
+ interp_weights=[structure_weight] * 7 +
163
+ [color_weight] * 11)
164
+ img_gen = torch.clamp(img_gen.detach(), -1, 1)
165
+ img_gen = self.postprocess(img_gen[0])
166
+ return img_gen
style.css ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+ img#overview {
5
+ max-width: 800px;
6
+ max-height: 600px;
7
+ display: block;
8
+ margin: auto;
9
+ }
10
+ img#style-image {
11
+ max-width: 1000px;
12
+ max-height: 600px;
13
+ }
14
+ img#visitor-badge {
15
+ display: block;
16
+ margin: auto;
17
+ }