Yannic Kilcher commited on
Commit
5c824be
·
1 Parent(s): 6f160b3

added interfaces for interpolation and projection

Browse files
Files changed (6) hide show
  1. .gitignore +1 -0
  2. README.md +5 -0
  3. interface.py +70 -0
  4. interface_projector.py +126 -0
  5. interpolate.py +10 -0
  6. projector.py +54 -5
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  __pycache__/
2
  .cache/
 
 
1
  __pycache__/
2
  .cache/
3
+ proj.mp4
README.md CHANGED
@@ -1,3 +1,8 @@
 
 
 
 
 
1
  ## StyleGAN2-ADA — Official PyTorch implementation
2
 
3
  ![Teaser image](./docs/stylegan2-ada-teaser-1024x252.png)
 
1
+ ## Project repo for apes by ykilcher
2
+
3
+ Note: most of the code is taken from nvlabs/stylegan2-ada-pytroch (original readme below).
4
+ I added gradio interfaces and CLIP projection.
5
+
6
  ## StyleGAN2-ADA — Official PyTorch implementation
7
 
8
  ![Teaser image](./docs/stylegan2-ada-teaser-1024x252.png)
interface.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import gradio as gr
4
+
5
+ import numpy as np
6
+ import torch
7
+ import pickle
8
+ import types
9
+
10
+ from huggingface_hub import hf_hub_url, cached_download
11
+
12
+ # with open('../models/gamma500/network-snapshot-010000.pkl', 'rb') as f:
13
+ with open(cached_download(hf_hub_url('ykilcher/apes', 'gamma500/network-snapshot-010000.pkl')), 'rb') as f:
14
+ G = pickle.load(f)['G_ema']# torch.nn.Module
15
+
16
+ device = torch.device("cpu")
17
+ if torch.cuda.is_available():
18
+ device = torch.device("cuda")
19
+ G = G.to(device)
20
+ else:
21
+ _old_forward = G.forward
22
+
23
+ def _new_forward(self, *args, **kwargs):
24
+ kwargs["force_fp32"] = True
25
+ return _old_forward(self, *args, **kwargs)
26
+
27
+ G.forward = types.MethodType(_new_forward, G)
28
+
29
+ _old_synthesis_forward = G.synthesis.forward
30
+
31
+ def _new_synthesis_forward(self, *args, **kwargs):
32
+ kwargs["force_fp32"] = True
33
+ return _old_synthesis_forward(self, *args, **kwargs)
34
+
35
+ G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
36
+
37
+
38
+ def generate(num_images, interpolate):
39
+ if interpolate:
40
+ z1 = torch.randn([1, G.z_dim])# latent codes
41
+ z2 = torch.randn([1, G.z_dim])# latent codes
42
+ zs = torch.cat([z1 + (z2 - z1) * i / (num_images-1) for i in range(num_images)], 0)
43
+ else:
44
+ zs = torch.randn([num_images, G.z_dim])# latent codes
45
+ with torch.no_grad():
46
+ zs = zs.to(device)
47
+ img = G(zs, None, force_fp32=True, truncation_psi=1, noise_mode='const')
48
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
49
+ return img.cpu().numpy()
50
+
51
+ def greet(num_images, interpolate):
52
+ img = generate(round(num_images), interpolate)
53
+ imgs = list(img)
54
+ if len(imgs) == 1:
55
+ return imgs[0]
56
+ grid_len = int(np.ceil(np.sqrt(len(imgs)))) * 2
57
+ grid_height = int(np.ceil(len(imgs) / grid_len))
58
+ grid = np.zeros((grid_height * imgs[0].shape[0], grid_len * imgs[0].shape[1], 3), dtype=np.uint8)
59
+ for i, img in enumerate(imgs):
60
+ y = (i // grid_len) * img.shape[0]
61
+ x = (i % grid_len) * img.shape[1]
62
+ grid[y:y+img.shape[0], x:x+img.shape[1], :] = img
63
+ return grid
64
+
65
+
66
+ iface = gr.Interface(fn=greet, inputs=[
67
+ gr.inputs.Number(default=1, label="Num Images"),
68
+ gr.inputs.Checkbox(default=False, label="Interpolate")
69
+ ], outputs="image")
70
+ iface.launch()
interface_projector.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import gradio as gr
4
+
5
+ import numpy as np
6
+ import torch
7
+ import pickle
8
+ import PIL.Image
9
+ import types
10
+
11
+ from projector import project, imageio, _MODELS
12
+
13
+ from huggingface_hub import hf_hub_url, cached_download
14
+
15
+ # with open("../models/gamma500/network-snapshot-010000.pkl", "rb") as f:
16
+ # with open("../models/gamma400/network-snapshot-010600.pkl", "rb") as f:
17
+ # with open("../models/gamma400/network-snapshot-019600.pkl", "rb") as f:
18
+ with open(cached_download(hf_hub_url('ykilcher/apes', 'gamma500/network-snapshot-010000.pkl')), 'rb') as f:
19
+ G = pickle.load(f)["G_ema"] # torch.nn.Module
20
+
21
+ device = torch.device("cpu")
22
+ if torch.cuda.is_available():
23
+ device = torch.device("cuda")
24
+ G = G.to(device)
25
+ else:
26
+ _old_forward = G.forward
27
+
28
+ def _new_forward(self, *args, **kwargs):
29
+ kwargs["force_fp32"] = True
30
+ return _old_forward(self, *args, **kwargs)
31
+
32
+ G.forward = types.MethodType(_new_forward, G)
33
+
34
+ _old_synthesis_forward = G.synthesis.forward
35
+
36
+ def _new_synthesis_forward(self, *args, **kwargs):
37
+ kwargs["force_fp32"] = True
38
+ return _old_synthesis_forward(self, *args, **kwargs)
39
+
40
+ G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
41
+
42
+
43
+ def generate(
44
+ target_image_upload,
45
+ # target_image_webcam,
46
+ num_steps,
47
+ seed,
48
+ learning_rate,
49
+ model_name,
50
+ normalize_for_clip,
51
+ loss_type,
52
+ regularize_noise_weight,
53
+ initial_noise_factor,
54
+ ):
55
+ seed = round(seed)
56
+ np.random.seed(seed)
57
+ torch.manual_seed(seed)
58
+ target_image = target_image_upload
59
+ # if target_image is None:
60
+ # target_image = target_image_webcam
61
+ num_steps = round(num_steps)
62
+ print(type(target_image))
63
+ print(target_image.dtype)
64
+ print(target_image.max())
65
+ print(target_image.min())
66
+ print(target_image.shape)
67
+ target_pil = PIL.Image.fromarray(target_image).convert("RGB")
68
+ w, h = target_pil.size
69
+ s = min(w, h)
70
+ target_pil = target_pil.crop(
71
+ ((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)
72
+ )
73
+ target_pil = target_pil.resize(
74
+ (G.img_resolution, G.img_resolution), PIL.Image.LANCZOS
75
+ )
76
+ target_uint8 = np.array(target_pil, dtype=np.uint8)
77
+ target_image = torch.from_numpy(target_uint8.transpose([2, 0, 1])).to(device)
78
+ projected_w_steps = project(
79
+ G,
80
+ target=target_image,
81
+ num_steps=num_steps,
82
+ device=device,
83
+ verbose=True,
84
+ initial_learning_rate=learning_rate,
85
+ model_name=model_name,
86
+ normalize_for_clip=normalize_for_clip,
87
+ loss_type=loss_type,
88
+ regularize_noise_weight=regularize_noise_weight,
89
+ initial_noise_factor=initial_noise_factor,
90
+ )
91
+ with torch.no_grad():
92
+ video = imageio.get_writer(f'proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
93
+ for w in projected_w_steps:
94
+ synth_image = G.synthesis(w.to(device).unsqueeze(0), noise_mode="const")
95
+ synth_image = (synth_image + 1) * (255 / 2)
96
+ synth_image = (
97
+ synth_image.permute(0, 2, 3, 1)
98
+ .clamp(0, 255)
99
+ .to(torch.uint8)[0]
100
+ .cpu()
101
+ .numpy()
102
+ )
103
+ video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
104
+ video.close()
105
+ return synth_image, "proj.mp4"
106
+
107
+
108
+ iface = gr.Interface(
109
+ fn=generate,
110
+ inputs=[
111
+ gr.inputs.Image(source="upload", optional=True),
112
+ # gr.inputs.Image(source="webcam", optional=True),
113
+ gr.inputs.Number(default=250, label="steps"),
114
+ gr.inputs.Number(default=69420, label="seed"),
115
+ gr.inputs.Number(default=0.05, label="learning_rate"),
116
+ gr.inputs.Dropdown(default='RN50', label="model_name", choices=['vgg16', *_MODELS.keys()]),
117
+ gr.inputs.Checkbox(default=True, label="normalize_for_clip"),
118
+ gr.inputs.Dropdown(
119
+ default="l2", label="loss_type", choices=["l2", "l1", "cosine"]
120
+ ),
121
+ gr.inputs.Number(default=1e5, label="regularize_noise_weight"),
122
+ gr.inputs.Number(default=0.05, label="initial_noise_factor"),
123
+ ],
124
+ outputs=["image", "video"],
125
+ )
126
+ iface.launch(inbrowser=True)
interpolate.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import torch
4
+ import pickle
5
+
6
+ with open('../models/gamma500/network-snapshot-010000.pkl', 'rb') as f:
7
+ G = pickle.load(f)['G_ema']# torch.nn.Module
8
+ z = torch.randn([1, G.z_dim])# latent codes
9
+ c = None # class labels (not used in this example)
10
+ img = G(z, c, force_fp32=True) # NCHW, float32, dynamic range [-1, +1]
projector.py CHANGED
@@ -22,6 +22,18 @@ import torch.nn.functional as F
22
  import dnnlib
23
  import legacy
24
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def project(
26
  G,
27
  target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
@@ -35,6 +47,9 @@ def project(
35
  noise_ramp_length = 0.75,
36
  regularize_noise_weight = 1e5,
37
  verbose = False,
 
 
 
38
  device: torch.device
39
  ):
40
  assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
@@ -56,16 +71,38 @@ def project(
56
  # Setup noise inputs.
57
  noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
58
 
 
59
  # Load VGG16 feature detector.
60
  url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
 
 
 
 
 
 
61
  with dnnlib.util.open_url(url) as f:
62
  vgg16 = torch.jit.load(f).eval().to(device)
63
 
64
  # Features for target image.
65
  target_images = target.unsqueeze(0).to(device).to(torch.float32)
66
- if target_images.shape[2] > 256:
67
- target_images = F.interpolate(target_images, size=(256, 256), mode='area')
68
- target_features = vgg16(target_images, resize_images=False, return_lpips=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
71
  w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
@@ -98,8 +135,20 @@ def project(
98
  synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
99
 
100
  # Features for synth images.
101
- synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
102
- dist = (target_features - synth_features).square().sum()
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  # Noise regularization.
105
  reg_loss = 0.0
 
22
  import dnnlib
23
  import legacy
24
 
25
+ _MODELS = {
26
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
27
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
28
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
29
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
30
+ "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
31
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
32
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
33
+ "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
34
+ "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
35
+ }
36
+
37
  def project(
38
  G,
39
  target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
 
47
  noise_ramp_length = 0.75,
48
  regularize_noise_weight = 1e5,
49
  verbose = False,
50
+ model_name='vgg16',
51
+ loss_type='l2',
52
+ normalize_for_clip=True,
53
  device: torch.device
54
  ):
55
  assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
 
71
  # Setup noise inputs.
72
  noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
73
 
74
+ USE_CLIP = model_name != 'vgg16'
75
  # Load VGG16 feature detector.
76
  url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
77
+ if USE_CLIP:
78
+ # url = 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt'
79
+ # url = 'https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt'
80
+ # url = 'https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt'
81
+ # url = 'https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt'
82
+ url = _MODELS[model_name]
83
  with dnnlib.util.open_url(url) as f:
84
  vgg16 = torch.jit.load(f).eval().to(device)
85
 
86
  # Features for target image.
87
  target_images = target.unsqueeze(0).to(device).to(torch.float32)
88
+ if USE_CLIP:
89
+ image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).to(device)[:, None, None]
90
+ image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).to(device)[:, None, None]
91
+ # target_images = F.interpolate(target_images, size=(224, 224), mode='area')
92
+ target_images = F.interpolate(target_images, size=(vgg16.input_resolution.item(), vgg16.input_resolution.item()), mode='area')
93
+ print("target_images.shape:", target_images.shape)
94
+ def _encode_image(image):
95
+ image = image / 255.
96
+ # image = torch.sigmoid(image)
97
+ if normalize_for_clip:
98
+ image = (image - image_mean) / image_std
99
+ return vgg16.encode_image(image)
100
+ target_features = _encode_image(target_images.clamp(0, 255))
101
+ target_features = target_features.detach()
102
+ else:
103
+ if target_images.shape[2] > 256:
104
+ target_images = F.interpolate(target_images, size=(256, 256), mode='area')
105
+ target_features = vgg16(target_images, resize_images=False, return_lpips=True)
106
 
107
  w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
108
  w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
 
135
  synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
136
 
137
  # Features for synth images.
138
+ if USE_CLIP:
139
+ synth_images = F.interpolate(synth_images, size=(vgg16.input_resolution.item(), vgg16.input_resolution.item()), mode='area')
140
+ synth_features = _encode_image(synth_images)
141
+ if loss_type == 'cosine':
142
+ target_features_normalized = target_features / target_features.norm(dim=-1, keepdim=True).detach()
143
+ synth_features_normalized = synth_features / synth_features.norm(dim=-1, keepdim=True).detach()
144
+ dist = 1.0 - torch.sum(synth_features_normalized * target_features_normalized)
145
+ elif loss_type == 'l1':
146
+ dist = (target_features - synth_features).abs().sum()
147
+ else:
148
+ dist = (target_features - synth_features).square().sum()
149
+ else:
150
+ synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
151
+ dist = (target_features - synth_features).square().sum()
152
 
153
  # Noise regularization.
154
  reg_loss = 0.0