Gustavo Belfort committed
Commit 7b31381
Parent: a7b2edd

update interface

Files changed (3):
  1. interface.py +88 -32
  2. interface_old.py +70 -0
  3. interface_projector.py +0 -126
interface.py CHANGED
@@ -5,13 +5,18 @@ import gradio as gr
 import numpy as np
 import torch
 import pickle
+import PIL.Image
 import types
 
+from projector import project, imageio, _MODELS
+
 from huggingface_hub import hf_hub_url, cached_download
 
-# with open('../models/gamma500/network-snapshot-010000.pkl', 'rb') as f:
+# with open("../models/gamma500/network-snapshot-010000.pkl", "rb") as f:
+# with open("../models/gamma400/network-snapshot-010600.pkl", "rb") as f:
+# with open("../models/gamma400/network-snapshot-019600.pkl", "rb") as f:
 with open(cached_download(hf_hub_url('ykilcher/apes', 'gamma500/network-snapshot-010000.pkl')), 'rb') as f:
-    G = pickle.load(f)['G_ema']# torch.nn.Module
+    G = pickle.load(f)["G_ema"] # torch.nn.Module
 
 device = torch.device("cpu")
 if torch.cuda.is_available():
@@ -35,36 +40,87 @@ else:
     G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
 
 
-def generate(num_images, interpolate):
-    if interpolate:
-        z1 = torch.randn([1, G.z_dim])# latent codes
-        z2 = torch.randn([1, G.z_dim])# latent codes
-        zs = torch.cat([z1 + (z2 - z1) * i / (num_images-1) for i in range(num_images)], 0)
-    else:
-        zs = torch.randn([num_images, G.z_dim])# latent codes
+def generate(
+    target_image_upload,
+    # target_image_webcam,
+    num_steps,
+    seed,
+    learning_rate,
+    model_name,
+    normalize_for_clip,
+    loss_type,
+    regularize_noise_weight,
+    initial_noise_factor,
+):
+    seed = round(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    target_image = target_image_upload
+    # if target_image is None:
+    #     target_image = target_image_webcam
+    num_steps = round(num_steps)
+    print(type(target_image))
+    print(target_image.dtype)
+    print(target_image.max())
+    print(target_image.min())
+    print(target_image.shape)
+    target_pil = PIL.Image.fromarray(target_image).convert("RGB")
+    w, h = target_pil.size
+    s = min(w, h)
+    target_pil = target_pil.crop(
+        ((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)
+    )
+    target_pil = target_pil.resize(
+        (G.img_resolution, G.img_resolution), PIL.Image.LANCZOS
+    )
+    target_uint8 = np.array(target_pil, dtype=np.uint8)
+    target_image = torch.from_numpy(target_uint8.transpose([2, 0, 1])).to(device)
+    projected_w_steps = project(
+        G,
+        target=target_image,
+        num_steps=num_steps,
+        device=device,
+        verbose=True,
+        initial_learning_rate=learning_rate,
+        model_name=model_name,
+        normalize_for_clip=normalize_for_clip,
+        loss_type=loss_type,
+        regularize_noise_weight=regularize_noise_weight,
+        initial_noise_factor=initial_noise_factor,
+    )
     with torch.no_grad():
-        zs = zs.to(device)
-        img = G(zs, None, force_fp32=True, noise_mode='const')
-        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
-        return img.cpu().numpy()
-
-def greet(num_images, interpolate):
-    img = generate(round(num_images), interpolate)
-    imgs = list(img)
-    if len(imgs) == 1:
-        return imgs[0]
-    grid_len = int(np.ceil(np.sqrt(len(imgs)))) * 2
-    grid_height = int(np.ceil(len(imgs) / grid_len))
-    grid = np.zeros((grid_height * imgs[0].shape[0], grid_len * imgs[0].shape[1], 3), dtype=np.uint8)
-    for i, img in enumerate(imgs):
-        y = (i // grid_len) * img.shape[0]
-        x = (i % grid_len) * img.shape[1]
-        grid[y:y+img.shape[0], x:x+img.shape[1], :] = img
-    return grid
+        video = imageio.get_writer(f'proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
+        for w in projected_w_steps:
+            synth_image = G.synthesis(w.to(device).unsqueeze(0), noise_mode="const")
+            synth_image = (synth_image + 1) * (255 / 2)
+            synth_image = (
+                synth_image.permute(0, 2, 3, 1)
+                .clamp(0, 255)
+                .to(torch.uint8)[0]
+                .cpu()
+                .numpy()
+            )
+            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
+        video.close()
+    return synth_image, "proj.mp4"
 
 
-iface = gr.Interface(fn=greet, inputs=[
-    gr.inputs.Slider(default=1, label="Num Images", minimum=1, maximum=9, step=1),
-    gr.inputs.Checkbox(default=False, label="Interpolate")
-], outputs="image")
-iface.launch()
+iface = gr.Interface(
+    fn=generate,
+    inputs=[
+        gr.inputs.Image(source="upload", optional=True),
+        # gr.inputs.Image(source="webcam", optional=True),
+        gr.inputs.Number(default=250, label="steps"),
+        gr.inputs.Number(default=69420, label="seed"),
+        gr.inputs.Number(default=0.05, label="learning_rate"),
+        gr.inputs.Dropdown(default='RN50', label="model_name", choices=['vgg16', *_MODELS.keys()]),
+        gr.inputs.Checkbox(default=True, label="normalize_for_clip"),
+        gr.inputs.Dropdown(
+            default="l2", label="loss_type", choices=["l2", "l1", "cosine"]
+        ),
+        gr.inputs.Number(default=1e5, label="regularize_noise_weight"),
+        gr.inputs.Number(default=0.05, label="initial_noise_factor"),
+    ],
+    outputs=["image", "video"],
+)
+iface.launch(inbrowser=True)
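As context for the diff above: the new generate() entry point center-crops the upload to a square, then resizes it to the generator's resolution before projection. A minimal standalone sketch of that preprocessing step (preprocess is a hypothetical helper; 256 stands in for G.img_resolution):

import numpy as np
import PIL.Image

def preprocess(target_image: np.ndarray, resolution: int = 256) -> np.ndarray:
    # Take the largest centered square, then LANCZOS-resize to the model's input size.
    target_pil = PIL.Image.fromarray(target_image).convert("RGB")
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((resolution, resolution), PIL.Image.LANCZOS)
    return np.array(target_pil, dtype=np.uint8)  # HWC uint8; the script transposes to CHW before calling project()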
interface_old.py ADDED
@@ -0,0 +1,70 @@
+#!/usr/bin/env python3
+
+import gradio as gr
+
+import numpy as np
+import torch
+import pickle
+import types
+
+from huggingface_hub import hf_hub_url, cached_download
+
+# with open('../models/gamma500/network-snapshot-010000.pkl', 'rb') as f:
+with open(cached_download(hf_hub_url('ykilcher/apes', 'gamma500/network-snapshot-010000.pkl')), 'rb') as f:
+    G = pickle.load(f)['G_ema']# torch.nn.Module
+
+device = torch.device("cpu")
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+    G = G.to(device)
+else:
+    _old_forward = G.forward
+
+    def _new_forward(self, *args, **kwargs):
+        kwargs["force_fp32"] = True
+        return _old_forward(*args, **kwargs)
+
+    G.forward = types.MethodType(_new_forward, G)
+
+    _old_synthesis_forward = G.synthesis.forward
+
+    def _new_synthesis_forward(self, *args, **kwargs):
+        kwargs["force_fp32"] = True
+        return _old_synthesis_forward(*args, **kwargs)
+
+    G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
+
+
+def generate(num_images, interpolate):
+    if interpolate:
+        z1 = torch.randn([1, G.z_dim])# latent codes
+        z2 = torch.randn([1, G.z_dim])# latent codes
+        zs = torch.cat([z1 + (z2 - z1) * i / (num_images-1) for i in range(num_images)], 0)
+    else:
+        zs = torch.randn([num_images, G.z_dim])# latent codes
+    with torch.no_grad():
+        zs = zs.to(device)
+        img = G(zs, None, force_fp32=True, noise_mode='const')
+        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+        return img.cpu().numpy()
+
+def greet(num_images, interpolate):
+    img = generate(round(num_images), interpolate)
+    imgs = list(img)
+    if len(imgs) == 1:
+        return imgs[0]
+    grid_len = int(np.ceil(np.sqrt(len(imgs)))) * 2
+    grid_height = int(np.ceil(len(imgs) / grid_len))
+    grid = np.zeros((grid_height * imgs[0].shape[0], grid_len * imgs[0].shape[1], 3), dtype=np.uint8)
+    for i, img in enumerate(imgs):
+        y = (i // grid_len) * img.shape[0]
+        x = (i % grid_len) * img.shape[1]
+        grid[y:y+img.shape[0], x:x+img.shape[1], :] = img
+    return grid
+
+
+iface = gr.Interface(fn=greet, inputs=[
+    gr.inputs.Slider(default=1, label="Num Images", minimum=1, maximum=9, step=1),
+    gr.inputs.Checkbox(default=False, label="Interpolate")
+], outputs="image")
+iface.launch()
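As context: the interpolate branch in the file above builds its latent batch by walking the straight line between two random codes. A minimal sketch of that construction (interpolation_batch is a hypothetical helper; 512 stands in for G.z_dim; assumes num_images >= 2, since the expression divides by num_images - 1):

import torch

def interpolation_batch(num_images: int, z_dim: int = 512) -> torch.Tensor:
    z1 = torch.randn([1, z_dim])  # endpoints of the interpolation
    z2 = torch.randn([1, z_dim])
    # i / (num_images - 1) sweeps 0..1, so rows run from z1 to z2 inclusive
    return torch.cat([z1 + (z2 - z1) * i / (num_images - 1) for i in range(num_images)], 0)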
interface_projector.py DELETED
@@ -1,126 +0,0 @@
-#!/usr/bin/env python3
-
-import gradio as gr
-
-import numpy as np
-import torch
-import pickle
-import PIL.Image
-import types
-
-from projector import project, imageio, _MODELS
-
-from huggingface_hub import hf_hub_url, cached_download
-
-# with open("../models/gamma500/network-snapshot-010000.pkl", "rb") as f:
-# with open("../models/gamma400/network-snapshot-010600.pkl", "rb") as f:
-# with open("../models/gamma400/network-snapshot-019600.pkl", "rb") as f:
-with open(cached_download(hf_hub_url('ykilcher/apes', 'gamma500/network-snapshot-010000.pkl')), 'rb') as f:
-    G = pickle.load(f)["G_ema"] # torch.nn.Module
-
-device = torch.device("cpu")
-if torch.cuda.is_available():
-    device = torch.device("cuda")
-    G = G.to(device)
-else:
-    _old_forward = G.forward
-
-    def _new_forward(self, *args, **kwargs):
-        kwargs["force_fp32"] = True
-        return _old_forward(*args, **kwargs)
-
-    G.forward = types.MethodType(_new_forward, G)
-
-    _old_synthesis_forward = G.synthesis.forward
-
-    def _new_synthesis_forward(self, *args, **kwargs):
-        kwargs["force_fp32"] = True
-        return _old_synthesis_forward(*args, **kwargs)
-
-    G.synthesis.forward = types.MethodType(_new_synthesis_forward, G.synthesis)
-
-
-def generate(
-    target_image_upload,
-    # target_image_webcam,
-    num_steps,
-    seed,
-    learning_rate,
-    model_name,
-    normalize_for_clip,
-    loss_type,
-    regularize_noise_weight,
-    initial_noise_factor,
-):
-    seed = round(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    target_image = target_image_upload
-    # if target_image is None:
-    #     target_image = target_image_webcam
-    num_steps = round(num_steps)
-    print(type(target_image))
-    print(target_image.dtype)
-    print(target_image.max())
-    print(target_image.min())
-    print(target_image.shape)
-    target_pil = PIL.Image.fromarray(target_image).convert("RGB")
-    w, h = target_pil.size
-    s = min(w, h)
-    target_pil = target_pil.crop(
-        ((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)
-    )
-    target_pil = target_pil.resize(
-        (G.img_resolution, G.img_resolution), PIL.Image.LANCZOS
-    )
-    target_uint8 = np.array(target_pil, dtype=np.uint8)
-    target_image = torch.from_numpy(target_uint8.transpose([2, 0, 1])).to(device)
-    projected_w_steps = project(
-        G,
-        target=target_image,
-        num_steps=num_steps,
-        device=device,
-        verbose=True,
-        initial_learning_rate=learning_rate,
-        model_name=model_name,
-        normalize_for_clip=normalize_for_clip,
-        loss_type=loss_type,
-        regularize_noise_weight=regularize_noise_weight,
-        initial_noise_factor=initial_noise_factor,
-    )
-    with torch.no_grad():
-        video = imageio.get_writer(f'proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
-        for w in projected_w_steps:
-            synth_image = G.synthesis(w.to(device).unsqueeze(0), noise_mode="const")
-            synth_image = (synth_image + 1) * (255 / 2)
-            synth_image = (
-                synth_image.permute(0, 2, 3, 1)
-                .clamp(0, 255)
-                .to(torch.uint8)[0]
-                .cpu()
-                .numpy()
-            )
-            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
-        video.close()
-    return synth_image, "proj.mp4"
-
-
-iface = gr.Interface(
-    fn=generate,
-    inputs=[
-        gr.inputs.Image(source="upload", optional=True),
-        # gr.inputs.Image(source="webcam", optional=True),
-        gr.inputs.Number(default=250, label="steps"),
-        gr.inputs.Number(default=69420, label="seed"),
-        gr.inputs.Number(default=0.05, label="learning_rate"),
-        gr.inputs.Dropdown(default='RN50', label="model_name", choices=['vgg16', *_MODELS.keys()]),
-        gr.inputs.Checkbox(default=True, label="normalize_for_clip"),
-        gr.inputs.Dropdown(
-            default="l2", label="loss_type", choices=["l2", "l1", "cosine"]
-        ),
-        gr.inputs.Number(default=1e5, label="regularize_noise_weight"),
-        gr.inputs.Number(default=0.05, label="initial_noise_factor"),
-    ],
-    outputs=["image", "video"],
-)
-iface.launch(inbrowser=True)
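As context for the deleted file (its content now lives in interface.py): the video loop maps synthesis output from the generator's [-1, 1] float range to displayable uint8 via (x + 1) * (255 / 2), so -1 lands on 0 and +1 on 255. A minimal sketch of that conversion (to_uint8_frame is a hypothetical helper):

import torch

def to_uint8_frame(synth_image: torch.Tensor) -> torch.Tensor:
    # NCHW float in [-1, 1] -> HWC uint8 in [0, 255]
    synth_image = (synth_image + 1) * (255 / 2)
    return synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu()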