amankishore committed
Commit 10528ca
1 Parent(s): 7a11626

app with gradio

Files changed (1):
  1. app.py (+187, -122)

app.py CHANGED

Before (removed lines are prefixed with "-"):

@@ -1,5 +1,8 @@
  import numpy as np
  import torch

  from my.utils import tqdm
  from my.utils.seed import seed_everything
@@ -12,33 +15,12 @@ from run_nerf import VoxConfig
  from voxnerf.utils import every
  from voxnerf.vis import stitch_vis, bad_vis as nerf_vis

- from run_sjc import render_one_view

- device_glb = torch.device("cuda")
-
- @torch.no_grad()
- def evaluate(score_model, vox, poser):
-     H, W = poser.H, poser.W
-     vox.eval()
-     K, poses = poser.sample_test(100)
-
-     aabb = vox.aabb.T.cpu().numpy()
-     vox = vox.to(device_glb)
-
-     num_imgs = len(poses)

-     for i in (pbar := tqdm(range(num_imgs))):
-
-         pose = poses[i]
-         y, depth = render_one_view(vox, aabb, H, W, K, pose)
-         if isinstance(score_model, StableDiffusion):
-             y = score_model.decode(y)
-         pane, img, depth = vis_routine(y, depth)
-
-     # metric.put_artifact(
-     #     "view_seq", ".mp4",
-     #     lambda fn: stitch_vis(fn, read_stats(metric.output_dir, "view")[1])
-     # )

  def vis_routine(y, depth):
      pane = nerf_vis(y, depth, final_H=256)
@@ -46,110 +28,193 @@ def vis_routine(y, depth):
      depth = depth.cpu().numpy()
      return pane, im, depth

-
- if __name__ == "__main__":
-     # cfgs = {'gddpm': {'model': 'm_lsun_256', 'lsun_cat': 'bedroom', 'imgnet_cat': -1}, 'sd': {'variant': 'v1', 'v2_highres': False, 'prompt': 'A high quality photo of a delicious burger', 'scale': 100.0, 'precision': 'autocast'}, 'lr': 0.05, 'n_steps': 10000, 'emptiness_scale': 10, 'emptiness_weight': 10000, 'emptiness_step': 0.5, 'emptiness_multiplier': 20.0, 'depth_weight': 0, 'var_red': True}
-     pose = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)
-     poser = pose.make()
-     sd_model = SD(variant='v1', v2_highres=False, prompt='A high quality photo of a delicious burger', scale=100.0, precision='autocast')
-     model = sd_model.make()
-     vox = VoxConfig(
-         model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
-         blend_bg_texture=True, bg_texture_hw=4,
-         bbox_len=1.0)
-     vox = vox.make()
-
-     lr = 0.05
-     n_steps = 10000
-     emptiness_scale = 10
-     emptiness_weight = 10000
-     emptiness_step = 0.5
-     emptiness_multiplier = 20.0
-     depth_weight = 0
-     var_red = True
-
-     assert model.samps_centered()
-     _, target_H, target_W = model.data_shape()
-     bs = 1
-     aabb = vox.aabb.T.cpu().numpy()
-     vox = vox.to(device_glb)
-     opt = torch.optim.Adamax(vox.opt_params(), lr=lr)
-
-     H, W = poser.H, poser.W
-     Ks, poses, prompt_prefixes = poser.sample_train(n_steps)
-
-     ts = model.us[30:-10]
-
-     same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)
-
-     with tqdm(total=n_steps) as pbar:
-         for i in range(n_steps):
-
-             p = f"{prompt_prefixes[i]} {model.prompt}"
-             score_conds = model.prompts_emb([p])
-
-             y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)
-
-             if isinstance(model, StableDiffusion):
-                 pass
-             else:
-                 y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')
-
-             opt.zero_grad()
-
-             with torch.no_grad():
-                 chosen_σs = np.random.choice(ts, bs, replace=False)
-                 chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
-                 chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)
-                 # chosen_σs = us[i]
-
-                 noise = torch.randn(bs, *y.shape[1:], device=model.device)
-
-                 zs = y + chosen_σs * noise
-                 Ds = model.denoise(zs, chosen_σs, **score_conds)
-
-                 if var_red:
-                     grad = (Ds - y) / chosen_σs
                  else:
-                     grad = (Ds - zs) / chosen_σs
-
-                 grad = grad.mean(0, keepdim=True)
-
-             y.backward(-grad, retain_graph=True)

-             if depth_weight > 0:
-                 center_depth = depth[7:-7, 7:-7]
-                 border_depth_mean = (depth.sum() - center_depth.sum()) / (64*64-50*50)
-                 center_depth_mean = center_depth.mean()
-                 depth_diff = center_depth_mean - border_depth_mean
-                 depth_loss = - torch.log(depth_diff + 1e-12)
-                 depth_loss = depth_weight * depth_loss
-                 depth_loss.backward(retain_graph=True)

-             emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
-             emptiness_loss = emptiness_weight * emptiness_loss
-             if emptiness_step * n_steps <= i:
-                 emptiness_loss *= emptiness_multiplier
-             emptiness_loss.backward()
-
-             opt.step()

-             # metric.put_scalars(**tsr_stats(y))

-             if every(pbar, percent=1):
-                 with torch.no_grad():
-                     if isinstance(model, StableDiffusion):
-                         y = model.decode(y)
-                     pane, img, depth = vis_routine(y, depth)

-                 # TODO: Output pane, img and depth to Gradio

-             pbar.update()
-             pbar.set_description(p)

-     # TODO: Save Checkpoint
-     ckpt = vox.state_dict()
-     # evaluate(model, vox, poser)

-     # TODO: Add code to stitch together the images and save them to a video

After (added lines are prefixed with "+"):

  import numpy as np
+ import time
+ from pathlib import Path
  import torch
+ import imageio

  from my.utils import tqdm
  from my.utils.seed import seed_everything

  from voxnerf.utils import every
  from voxnerf.vis import stitch_vis, bad_vis as nerf_vis

+ from run_sjc import render_one_view, tsr_stats

+ import gradio as gr
+ import gc

+ device_glb = torch.device("cuda")

  def vis_routine(y, depth):
      pane = nerf_vis(y, depth, final_H=256)

      depth = depth.cpu().numpy()
      return pane, im, depth

+ with gr.Blocks(css=".gradio-container {max-width: 512px; margin: auto;}") as demo:
+     # title
+     gr.Markdown('[Score Jacobian Chaining](https://github.com/pals-ttic/sjc) Lifting Pretrained 2D Diffusion Models for 3D Generation')
+
+     # inputs
+     prompt = gr.Textbox(label="Prompt", max_lines=1, value="A high quality photo of a delicious burger")
+     iters = gr.Slider(label="Iters", minimum=1000, maximum=20000, value=10000, step=100)
+     seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
+     button = gr.Button('Generate')
+
+     # outputs
+     image = gr.Image(label="image", visible=True)
+     depth = gr.Image(label="depth", visible=True)
+     video = gr.Video(label="video", visible=False)
+     logs = gr.Textbox(label="logging")
+
+     def submit(prompt, iters, seed):
+         start_t = time.time()
+         seed_everything(seed)
+         # cfgs = {'gddpm': {'model': 'm_lsun_256', 'lsun_cat': 'bedroom', 'imgnet_cat': -1}, 'sd': {'variant': 'v1', 'v2_highres': False, 'prompt': 'A high quality photo of a delicious burger', 'scale': 100.0, 'precision': 'autocast'}, 'lr': 0.05, 'n_steps': 10000, 'emptiness_scale': 10, 'emptiness_weight': 10000, 'emptiness_step': 0.5, 'emptiness_multiplier': 20.0, 'depth_weight': 0, 'var_red': True}
+         pose = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)
+         poser = pose.make()
+         sd_model = SD(variant='v1', v2_highres=False, prompt=prompt, scale=100.0, precision='autocast')
+         model = sd_model.make()
+         vox = VoxConfig(
+             model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
+             blend_bg_texture=True, bg_texture_hw=4,
+             bbox_len=1.0)
+         vox = vox.make()
+
+         lr = 0.05
+         n_steps = iters
+         emptiness_scale = 10
+         emptiness_weight = 10000
+         emptiness_step = 0.5
+         emptiness_multiplier = 20.0
+         depth_weight = 0
+         var_red = True
+
+         assert model.samps_centered()
+         _, target_H, target_W = model.data_shape()
+         bs = 1
+         aabb = vox.aabb.T.cpu().numpy()
+         vox = vox.to(device_glb)
+         opt = torch.optim.Adamax(vox.opt_params(), lr=lr)
+
+         H, W = poser.H, poser.W
+         Ks, poses, prompt_prefixes = poser.sample_train(n_steps)
+
+         ts = model.us[30:-10]
+
+         same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)
+
+         with tqdm(total=n_steps) as pbar:
+             for i in range(n_steps):
+
+                 p = f"{prompt_prefixes[i]} {model.prompt}"
+                 score_conds = model.prompts_emb([p])
+
+                 y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)
+
+                 if isinstance(model, StableDiffusion):
+                     pass
                  else:
+                     y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

+                 opt.zero_grad()

+                 with torch.no_grad():
+                     chosen_σs = np.random.choice(ts, bs, replace=False)
+                     chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
+                     chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)
+                     # chosen_σs = us[i]

+                     noise = torch.randn(bs, *y.shape[1:], device=model.device)

+                     zs = y + chosen_σs * noise
+                     Ds = model.denoise(zs, chosen_σs, **score_conds)

+                     if var_red:
+                         grad = (Ds - y) / chosen_σs
+                     else:
+                         grad = (Ds - zs) / chosen_σs

+                     grad = grad.mean(0, keepdim=True)
+
+                 y.backward(-grad, retain_graph=True)

+                 if depth_weight > 0:
+                     center_depth = depth[7:-7, 7:-7]
+                     border_depth_mean = (depth.sum() - center_depth.sum()) / (64*64-50*50)
+                     center_depth_mean = center_depth.mean()
+                     depth_diff = center_depth_mean - border_depth_mean
+                     depth_loss = - torch.log(depth_diff + 1e-12)
+                     depth_loss = depth_weight * depth_loss
+                     depth_loss.backward(retain_graph=True)
+
+                 emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
+                 emptiness_loss = emptiness_weight * emptiness_loss
+                 if emptiness_step * n_steps <= i:
+                     emptiness_loss *= emptiness_multiplier
+                 emptiness_loss.backward()
+
+                 opt.step()
+
+
+                 # metric.put_scalars()
+
+                 if every(pbar, percent=1):
+                     with torch.no_grad():
+                         if isinstance(model, StableDiffusion):
+                             y = model.decode(y)
+                         pane, img, depth = vis_routine(y, depth)
+
+                 # TODO: Output pane, img and depth to Gradio
+
+                 pbar.update()
+                 pbar.set_description(p)
+
+                 yield {
+                     image: gr.update(value=img, visible=True),
+                     depth: gr.update(value=depth, visible=True),
+                     video: gr.update(visible=False),
+                     logs: str(tsr_stats(y)),
+                 }
+
+         # TODO: Save Checkpoint
+         ckpt = vox.state_dict()
+         H, W = poser.H, poser.W
+         vox.eval()
+         K, poses = poser.sample_test(100)
+
+         aabb = vox.aabb.T.cpu().numpy()
+         vox = vox.to(device_glb)
+
+         num_imgs = len(poses)
+
+         for i in (pbar := tqdm(range(num_imgs))):
+
+             pose = poses[i]
+             y, depth = render_one_view(vox, aabb, H, W, K, pose)
+             if isinstance(model, StableDiffusion):
+                 y = model.decode(y)
+             pane, img, depth = vis_routine(y, depth)
+
+             # Save img to output
+             img.save(f"output/{i}.png")
+
+             yield {
+                 image: gr.update(value=img, visible=True),
+                 depth: gr.update(value=depth, visible=True),
+                 video: gr.update(visible=False),
+                 logs: str(tsr_stats(y)),
+             }
+
+         output_video = "view_seq.mp4"
+
+         def export_movie(seqs, fname, fps=30):
+             fname = Path(fname)
+             if fname.suffix == "":
+                 fname = fname.with_suffix(".mp4")
+             writer = imageio.get_writer(fname, fps=fps)
+             for img in seqs:
+                 writer.append_data(img)
+             writer.close()
+
+         def stitch_vis(save_fn, img_fnames, fps=10):
+             figs = [imageio.imread(fn) for fn in img_fnames]
+             export_movie(figs, save_fn, fps)
+
+         stitch_vis(output_video, [f"output/{i}.png" for i in range(num_imgs)])

+         end_t = time.time()

+         yield {
+             image: gr.update(value=img, visible=False),
+             depth: gr.update(value=depth, visible=False),
+             video: gr.update(value=output_video, visible=True),
+             logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
+         }
+
+     button.click(
+         submit,
+         [prompt, iters, seed],
+         [image, depth, video, logs]
+     )
+
+ # concurrency_count: only allow ONE running progress, else GPU will OOM.
+ demo.queue(concurrency_count=1)
+ demo.launch()
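
Note on the streaming pattern used by submit() above: it is a Python generator, and each yield pushes a dict of gr.update(...) values to the output components bound in button.click() while the run keeps going. The snippet below is a minimal, self-contained sketch of that pattern, assuming Gradio 3.x (where demo.queue(concurrency_count=...) is available); the placeholder run() loop stands in for the real optimization and is not part of this commit.

import time

import gradio as gr

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    button = gr.Button("Generate")
    logs = gr.Textbox(label="logging")

    def run(prompt):
        # A generator handler: every yield streams an update to the bound
        # outputs while the job is still running, as submit() does above.
        for step in range(3):
            time.sleep(1)  # placeholder for one chunk of optimization work
            yield {logs: f"step {step}: optimizing for '{prompt}'"}
        yield {logs: "done"}

    button.click(run, [prompt], [logs])

# One queued job at a time, mirroring the app's guard against GPU OOM.
demo.queue(concurrency_count=1)
demo.launch()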