turn-the-cam-anonymous committed on
Commit
03e871c
1 Parent(s): dc1ad90

taming directory

Files changed (4)
  1. CLIP +1 -0
  2. app.py +543 -99
  3. gradio_new.py +0 -663
  4. taming-transformers +1 -0
CLIP ADDED
@@ -0,0 +1 @@
1
+ Subproject commit a9b1bf5920416aaeaec965c25dd9e8f98c864f16
app.py CHANGED
@@ -1,25 +1,55 @@
1
  import math
2
  import fire
3
  import gradio as gr
 
 
4
  import numpy as np
 
 
5
  import rich
 
 
6
  import torch
7
  from contextlib import nullcontext
 
8
  from einops import rearrange
9
  from functools import partial
10
  from ldm.models.diffusion.ddim import DDIMSampler
11
- from ldm.util import load_and_preprocess, instantiate_from_config
 
12
  from omegaconf import OmegaConf
13
  from PIL import Image
14
  from rich import print
 
15
  from torch import autocast
16
  from torchvision import transforms
17
 
18
 
19
- _SHOW_INTERMEDIATE = True
 
 
20
  _GPU_INDEX = 0
21
  # _GPU_INDEX = 2
22
 
23
 
24
  def load_model_from_config(config, ckpt, device, verbose=False):
25
  print(f'Loading model from {ckpt}')
@@ -81,27 +111,180 @@ def sample_model(input_im, model, sampler, precision, h, w, ddim_steps, n_sample
81
  return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
82
 
83
 
84
- def main(
85
- model,
86
- device,
87
- input_im,
88
- preprocess=True,
89
- x=0.,
90
- y=0.,
91
- z=0.,
92
- scale=3.0,
93
- n_samples=4,
94
- ddim_steps=50,
95
- ddim_eta=1.0,
96
- precision='fp32',
97
- h=256,
98
- w=256,
99
- ):
100
- # input_im[input_im == [0., 0., 0.]] = [1., 1., 1., 1.]
101
  print('old input_im:', input_im.size)
 
102
 
103
  if preprocess:
104
- input_im = load_and_preprocess(input_im)
105
  input_im = (input_im / 255.0).astype(np.float32)
106
  # (H, W, 3) array in [0, 1].
107
 
@@ -109,8 +292,8 @@ def main(
109
  input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
110
  input_im = np.asarray(input_im, dtype=np.float32) / 255.0
111
  # (H, W, 4) array in [0, 1].
112
-
113
- # old method: very important, thresholding background
114
  # input_im[input_im[:, :, -1] <= 0.9] = [1., 1., 1., 1.]
115
 
116
  # new method: apply correct method of compositing to avoid sudden transitions / thresholding
@@ -118,102 +301,363 @@ def main(
118
  alpha = input_im[:, :, 3:4]
119
  white_im = np.ones_like(input_im)
120
  input_im = alpha * input_im + (1.0 - alpha) * white_im
121
-
122
  input_im = input_im[:, :, 0:3]
123
  # (H, W, 3) array in [0, 1].
124
 
125
- print('new input_im:', input_im.shape, input_im.dtype, input_im.min(), input_im.max())
126
- show_in_im = Image.fromarray((input_im * 255).astype(np.uint8))
127
 
128
- input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
129
- input_im = input_im * 2 - 1
130
- input_im = transforms.functional.resize(input_im, [h, w])
131
 
132
- sampler = DDIMSampler(model)
133
- x_samples_ddim = sample_model(input_im, model, sampler, precision, h, w,
134
- ddim_steps, n_samples, scale, ddim_eta, x, y, z)
135
 
136
- output_ims = []
137
- for x_sample in x_samples_ddim:
138
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
139
- output_ims.append(Image.fromarray(x_sample.astype(np.uint8)))
140
 
141
- if _SHOW_INTERMEDIATE:
142
- return (output_ims, show_in_im)
143
  else:
144
- return output_ims
145
 
 
146
 
147
- description = '''
148
- Generate novel viewpoints of an object depicted in one input image using a fine-tuned version of Stable Diffusion.
149
- '''
150
 
151
- article = '''
152
- ## How to use this?
153
- TBD
154
- ## How does this work?
155
- TBD
156
- '''
157
 
158
 
159
  def run_demo(
160
- device_idx=_GPU_INDEX,
161
- ckpt='last.ckpt',
162
- config='configs/sd-objaverse-finetune-c_concat-256.yaml',
163
- ):
164
 
165
  device = f'cuda:{device_idx}'
166
  config = OmegaConf.load(config)
167
- model = load_model_from_config(config, ckpt, device=device)
168
-
169
- inputs = [
170
- gr.Image(type='pil', image_mode='RGBA'), # shape=[512, 512]
171
- gr.Checkbox(True, label='Preprocess image (remove background and center)',
172
- info='If enabled, the uploaded image will be preprocessed to remove the background and center the object by cropping and/or padding as necessary. '
173
- 'If disabled, the image will be used as-is, *BUT* a fully transparent or white background is required.'),
174
- # gr.Number(label='polar (between axis z+)'),
175
- # gr.Number(label='azimuth (between axis x+)'),
176
- # gr.Number(label='z (distance from center)'),
177
- gr.Slider(-90, 90, value=0, step=5, label='Polar angle (vertical rotation in degrees)',
178
- info='Positive values move the camera down, while negative values move the camera up.'),
179
- gr.Slider(-90, 90, value=0, step=5, label='Azimuth angle (horizontal rotation in degrees)',
180
- info='Positive values move the camera right, while negative values move the camera left.'),
181
- gr.Slider(-2, 2, value=0, step=0.5, label='Radius (distance from center)',
182
- info='Positive values move the camera further away, while negative values move the camera closer.'),
183
- gr.Slider(0, 30, value=3, step=1, label='cfg scale'),
184
- gr.Slider(1, 8, value=4, step=1, label='Number of samples to generate'),
185
- gr.Slider(5, 200, value=100, step=5, label='Number of steps'),
186
- ]
187
- output = [gr.Gallery(label='Generated images from specified new viewpoint')]
188
- output[0].style(grid=2)
189
-
190
- if _SHOW_INTERMEDIATE:
191
- output += [gr.Image(type='pil', image_mode='RGB', label='Preprocessed input image')]
192
-
193
- fn_with_model = partial(main, model, device)
194
- fn_with_model.__name__ = 'fn_with_model'
195
-
196
- examples = [
197
- # ['assets/zero-shot/bear.png', 0, 0, 0, 3, 4, 100],
198
- # ['assets/zero-shot/car.png', 0, 0, 0, 3, 4, 100],
199
- # ['assets/zero-shot/elephant.png', 0, 0, 0, 3, 4, 100],
200
- # ['assets/zero-shot/pikachu.png', 0, 0, 0, 3, 4, 100],
201
- # ['assets/zero-shot/spyro.png', 0, 0, 0, 3, 4, 100],
202
- # ['assets/zero-shot/taxi.png', 0, 0, 0, 3, 4, 100],
203
- ]
204
-
205
- demo = gr.Interface(
206
- fn=fn_with_model,
207
- title='Demo for Zero-Shot Control of Camera Viewpoints within a Single Image',
208
- description=description,
209
- article=article,
210
- inputs=inputs,
211
- outputs=output,
212
- examples=examples,
213
- allow_flagging='never',
214
- )
215
  demo.launch(enable_queue=True, share=True)
216
 
217
 
218
  if __name__ == '__main__':
 
219
  fire.Fire(run_demo)
 
1
+ '''
2
+ conda activate zero123
3
+ cd stable-diffusion
4
+ python gradio_new.py 0
5
+ '''
6
+
7
+ import diffusers # 0.12.1
8
  import math
9
  import fire
10
  import gradio as gr
11
+ import lovely_numpy
12
+ import lovely_tensors
13
  import numpy as np
14
+ import plotly.express as px
15
+ import plotly.graph_objects as go
16
  import rich
17
+ import sys
18
+ import time
19
  import torch
20
  from contextlib import nullcontext
21
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
22
  from einops import rearrange
23
  from functools import partial
24
  from ldm.models.diffusion.ddim import DDIMSampler
25
+ from ldm.util import create_carvekit_interface, load_and_preprocess, instantiate_from_config
26
+ from lovely_numpy import lo
27
  from omegaconf import OmegaConf
28
  from PIL import Image
29
  from rich import print
30
+ from transformers import AutoFeatureExtractor #, CLIPImageProcessor
31
  from torch import autocast
32
  from torchvision import transforms
33
 
34
 
35
+ _SHOW_DESC = True
36
+ _SHOW_INTERMEDIATE = False
37
+ # _SHOW_INTERMEDIATE = True
38
  _GPU_INDEX = 0
39
  # _GPU_INDEX = 2
40
 
41
+ # _TITLE = 'Zero-Shot Control of Camera Viewpoints within a Single Image'
42
+ _TITLE = 'Zero-1-to-3: Zero-shot One Image to 3D Object'
43
+
44
+ # This demo allows you to generate novel viewpoints of an object depicted in an input image using a fine-tuned version of Stable Diffusion.
45
+ _DESCRIPTION = '''
46
+ This demo allows you to control camera rotation and thereby generate novel viewpoints of an object within a single image.
47
+ It is based on Stable Diffusion. Check out our [project webpage](https://zero123.cs.columbia.edu/) and [paper](https://arxiv.org/) if you want to learn more about the method!
48
+ Note that this model is not intended for images of humans or faces, and is unlikely to work well for them.
49
+ '''
50
+
51
+ _ARTICLE = 'See uses.md'
52
+
53
 
54
  def load_model_from_config(config, ckpt, device, verbose=False):
55
  print(f'Loading model from {ckpt}')
 
111
  return torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
112
 
113
 
114
+ class CameraVisualizer:
115
+ def __init__(self, gradio_plot):
116
+ self._gradio_plot = gradio_plot
117
+ self._fig = None
118
+ self._polar = 0.0
119
+ self._azimuth = 0.0
120
+ self._radius = 0.0
121
+ self._raw_image = None
122
+ self._8bit_image = None
123
+ self._image_colorscale = None
124
+
125
+ def polar_change(self, value):
126
+ self._polar = value
127
+ # return self.update_figure()
128
+
129
+ def azimuth_change(self, value):
130
+ self._azimuth = value
131
+ # return self.update_figure()
132
+
133
+ def radius_change(self, value):
134
+ self._radius = value
135
+ # return self.update_figure()
136
+
137
+ def encode_image(self, raw_image):
138
+ '''
139
+ :param raw_image (H, W, 3) array of uint8 in [0, 255].
140
+ '''
141
+ # https://stackoverflow.com/questions/60685749/python-plotly-how-to-add-an-image-to-a-3d-scatter-plot
142
+
143
+ dum_img = Image.fromarray(np.ones((3, 3, 3), dtype='uint8')).convert('P', palette='WEB')
144
+ idx_to_color = np.array(dum_img.getpalette()).reshape((-1, 3))
145
+
146
+ self._raw_image = raw_image
147
+ self._8bit_image = Image.fromarray(raw_image).convert('P', palette='WEB', dither=None)
148
+ # self._8bit_image = Image.fromarray(raw_image.clip(0, 254)).convert(
149
+ # 'P', palette='WEB', dither=None)
150
+ self._image_colorscale = [
151
+ [i / 255.0, 'rgb({}, {}, {})'.format(*rgb)] for i, rgb in enumerate(idx_to_color)]
152
+
153
+ # return self.update_figure()
154
+
155
+ def update_figure(self):
156
+ fig = go.Figure()
157
+
158
+ if self._raw_image is not None:
159
+ (H, W, C) = self._raw_image.shape
160
+
161
+ x = np.zeros((H, W))
162
+ (y, z) = np.meshgrid(np.linspace(-1.0, 1.0, W), np.linspace(1.0, -1.0, H) * H / W)
163
+ print('x:', lo(x))
164
+ print('y:', lo(y))
165
+ print('z:', lo(z))
166
+
167
+ fig.add_trace(go.Surface(
168
+ x=x, y=y, z=z,
169
+ surfacecolor=self._8bit_image,
170
+ cmin=0,
171
+ cmax=255,
172
+ colorscale=self._image_colorscale,
173
+ showscale=False,
174
+ lighting_diffuse=1.0,
175
+ lighting_ambient=1.0,
176
+ lighting_fresnel=1.0,
177
+ lighting_roughness=1.0,
178
+ lighting_specular=0.3))
179
+
180
+ scene_bounds = 3.5
181
+ base_radius = 2.5
182
+ zoom_scale = 1.5 # Note that input radius offset is in [-0.5, 0.5].
183
+ fov_deg = 50.0
184
+ edges = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4), (4, 1)]
185
+
186
+ input_cone = calc_cam_cone_pts_3d(
187
+ 0.0, 0.0, base_radius, fov_deg) # (5, 3).
188
+ output_cone = calc_cam_cone_pts_3d(
189
+ self._polar, self._azimuth, base_radius + self._radius * zoom_scale, fov_deg) # (5, 3).
190
+ # print('input_cone:', lo(input_cone).v)
191
+ # print('output_cone:', lo(output_cone).v)
192
+
193
+ for (cone, clr, legend) in [(input_cone, 'green', 'Input view'),
194
+ (output_cone, 'blue', 'Target view')]:
195
+
196
+ for (i, edge) in enumerate(edges):
197
+ (x1, x2) = (cone[edge[0], 0], cone[edge[1], 0])
198
+ (y1, y2) = (cone[edge[0], 1], cone[edge[1], 1])
199
+ (z1, z2) = (cone[edge[0], 2], cone[edge[1], 2])
200
+ fig.add_trace(go.Scatter3d(
201
+ x=[x1, x2], y=[y1, y2], z=[z1, z2], mode='lines',
202
+ line=dict(color=clr, width=3),
203
+ name=legend, showlegend=(i == 0)))
204
+ # text=(legend if i == 0 else None),
205
+ # textposition='bottom center'))
206
+ # hoverinfo='text',
207
+ # hovertext='hovertext'))
208
+
209
+ # Add label.
210
+ if cone[0, 2] <= base_radius / 2.0:
211
+ fig.add_trace(go.Scatter3d(
212
+ x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] - 0.05], showlegend=False,
213
+ mode='text', text=legend, textposition='bottom center'))
214
+ else:
215
+ fig.add_trace(go.Scatter3d(
216
+ x=[cone[0, 0]], y=[cone[0, 1]], z=[cone[0, 2] + 0.05], showlegend=False,
217
+ mode='text', text=legend, textposition='top center'))
218
+
219
+ # look at center of scene
220
+ fig.update_layout(
221
+ # width=640,
222
+ # height=480,
223
+ # height=400,
224
+ height=360,
225
+ autosize=True,
226
+ hovermode=False,
227
+ margin=go.layout.Margin(l=0, r=0, b=0, t=0),
228
+ showlegend=True,
229
+ legend=dict(
230
+ yanchor='bottom',
231
+ y=0.01,
232
+ xanchor='right',
233
+ x=0.99,
234
+ ),
235
+ scene=dict(
236
+ aspectmode='manual',
237
+ aspectratio=dict(x=1, y=1, z=1.0),
238
+ camera=dict(
239
+ eye=dict(x=base_radius - 1.6, y=0.0, z=0.6),
240
+ center=dict(x=0.0, y=0.0, z=0.0),
241
+ up=dict(x=0.0, y=0.0, z=1.0)),
242
+ xaxis_title='',
243
+ yaxis_title='',
244
+ zaxis_title='',
245
+ xaxis=dict(
246
+ range=[-scene_bounds, scene_bounds],
247
+ showticklabels=False,
248
+ showgrid=True,
249
+ zeroline=False,
250
+ showbackground=True,
251
+ showspikes=False,
252
+ showline=False,
253
+ ticks=''),
254
+ yaxis=dict(
255
+ range=[-scene_bounds, scene_bounds],
256
+ showticklabels=False,
257
+ showgrid=True,
258
+ zeroline=False,
259
+ showbackground=True,
260
+ showspikes=False,
261
+ showline=False,
262
+ ticks=''),
263
+ zaxis=dict(
264
+ range=[-scene_bounds, scene_bounds],
265
+ showticklabels=False,
266
+ showgrid=True,
267
+ zeroline=False,
268
+ showbackground=True,
269
+ showspikes=False,
270
+ showline=False,
271
+ ticks='')))
272
+
273
+ self._fig = fig
274
+ return fig
275
+
276
+
277
+ def preprocess_image(models, input_im, preprocess):
278
+ '''
279
+ :param input_im (PIL Image).
280
+ :return input_im (H, W, 3) array in [0, 1].
281
+ '''
282
+
283
  print('old input_im:', input_im.size)
284
+ start_time = time.time()
285
 
286
  if preprocess:
287
+ input_im = load_and_preprocess(models['carvekit'], input_im)
288
  input_im = (input_im / 255.0).astype(np.float32)
289
  # (H, W, 3) array in [0, 1].
290
 
 
292
  input_im = input_im.resize([256, 256], Image.Resampling.LANCZOS)
293
  input_im = np.asarray(input_im, dtype=np.float32) / 255.0
294
  # (H, W, 4) array in [0, 1].
295
+
296
+ # old method: thresholding background, very important
297
  # input_im[input_im[:, :, -1] <= 0.9] = [1., 1., 1., 1.]
298
 
299
  # new method: apply correct method of compositing to avoid sudden transitions / thresholding
 
301
  alpha = input_im[:, :, 3:4]
302
  white_im = np.ones_like(input_im)
303
  input_im = alpha * input_im + (1.0 - alpha) * white_im
304
+
305
  input_im = input_im[:, :, 0:3]
306
  # (H, W, 3) array in [0, 1].
307
 
308
+ print(f'Infer foreground mask (preprocess_image) took {time.time() - start_time:.3f}s.')
309
+ print('new input_im:', lo(input_im))
310
 
311
+ return input_im
 
 
312
 
 
 
 
313
 
314
+ def main_run(models, device, cam_vis, return_what,
315
+ x=0.0, y=0.0, z=0.0,
316
+ raw_im=None, preprocess=True,
317
+ scale=3.0, n_samples=4, ddim_steps=50, ddim_eta=1.0,
318
+ precision='fp32', h=256, w=256):
319
+ '''
320
+ :param raw_im (PIL Image).
321
+ '''
322
 
323
+ safety_checker_input = models['clip_fe'](raw_im, return_tensors='pt').to(device)
324
+ (image, has_nsfw_concept) = models['nsfw'](
325
+ images=np.ones((1, 3)), clip_input=safety_checker_input.pixel_values)
326
+ print('has_nsfw_concept:', has_nsfw_concept)
327
+ if np.any(has_nsfw_concept):
328
+ print('NSFW content detected.')
329
+ to_return = [None] * 10
330
+ description = ('### <span style="color:red"> Unfortunately, '
331
+ 'potential NSFW content was detected, '
332
+ 'which is not supported by our model. '
333
+ 'Please try again with a different image. </span>')
334
+ if 'angles' in return_what:
335
+ to_return[0] = 0.0
336
+ to_return[1] = 0.0
337
+ to_return[2] = 0.0
338
+ to_return[3] = description
339
+ else:
340
+ to_return[0] = description
341
+ return to_return
342
+
343
  else:
344
+ print('Safety check passed.')
345
 
346
+ input_im = preprocess_image(models, raw_im, preprocess)
347
 
348
+ # if np.random.rand() < 0.3:
349
+ # description = ('Unfortunately, a human, a face, or potential NSFW content was detected, '
350
+ # 'which is not supported by our model.')
351
+ # if vis_only:
352
+ # return (None, None, description)
353
+ # else:
354
+ # return (None, None, None, description)
355
 
356
+ show_in_im1 = (input_im * 255.0).astype(np.uint8)
357
+ show_in_im2 = Image.fromarray(show_in_im1)
358
+
359
+ if 'rand' in return_what:
360
+ x = int(np.round(np.arcsin(np.random.uniform(-1.0, 1.0)) * 160.0 / np.pi)) # [-80, 80].
361
+ y = int(np.round(np.random.uniform(-150.0, 150.0)))
362
+ z = 0.0
363
+
364
+ cam_vis.polar_change(x)
365
+ cam_vis.azimuth_change(y)
366
+ cam_vis.radius_change(z)
367
+ cam_vis.encode_image(show_in_im1)
368
+ new_fig = cam_vis.update_figure()
369
+
370
+ if 'vis' in return_what:
371
+ description = ('The viewpoints are visualized on the top right. '
372
+ 'Click Run Generation to update the results on the bottom right.')
373
+
374
+ if 'angles' in return_what:
375
+ return (x, y, z, description, new_fig, show_in_im2)
376
+ else:
377
+ return (description, new_fig, show_in_im2)
378
+
379
+ elif 'gen' in return_what:
380
+ input_im = transforms.ToTensor()(input_im).unsqueeze(0).to(device)
381
+ input_im = input_im * 2 - 1
382
+ input_im = transforms.functional.resize(input_im, [h, w])
383
+
384
+ sampler = DDIMSampler(models['turncam'])
385
+ # used_x = -x # NOTE: Polar makes more sense in Basile's opinion this way!
386
+ used_x = x # NOTE: Set this way for consistency.
387
+ x_samples_ddim = sample_model(input_im, models['turncam'], sampler, precision, h, w,
388
+ ddim_steps, n_samples, scale, ddim_eta, used_x, y, z)
389
+
390
+ output_ims = []
391
+ for x_sample in x_samples_ddim:
392
+ x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
393
+ output_ims.append(Image.fromarray(x_sample.astype(np.uint8)))
394
+
395
+ description = None
396
+
397
+ if 'angles' in return_what:
398
+ return (x, y, z, description, new_fig, show_in_im2, output_ims)
399
+ else:
400
+ return (description, new_fig, show_in_im2, output_ims)
401
+
402
+
403
+ def calc_cam_cone_pts_3d(polar_deg, azimuth_deg, radius_m, fov_deg):
404
+ '''
405
+ :param polar_deg (float).
406
+ :param azimuth_deg (float).
407
+ :param radius_m (float).
408
+ :param fov_deg (float).
409
+ :return (5, 3) array of float with (x, y, z).
410
+ '''
411
+ polar_rad = np.deg2rad(polar_deg)
412
+ azimuth_rad = np.deg2rad(azimuth_deg)
413
+ fov_rad = np.deg2rad(fov_deg)
414
+ polar_rad = -polar_rad # NOTE: Inverse of how used_x relates to x.
415
+
416
+ # Camera pose center:
417
+ cam_x = radius_m * np.cos(azimuth_rad) * np.cos(polar_rad)
418
+ cam_y = radius_m * np.sin(azimuth_rad) * np.cos(polar_rad)
419
+ cam_z = radius_m * np.sin(polar_rad)
420
+
421
+ # Obtain four corners of camera frustum, assuming it is looking at origin.
422
+ # First, obtain camera extrinsics (rotation matrix only):
423
+ camera_R = np.array([[np.cos(azimuth_rad) * np.cos(polar_rad),
424
+ -np.sin(azimuth_rad),
425
+ -np.cos(azimuth_rad) * np.sin(polar_rad)],
426
+ [np.sin(azimuth_rad) * np.cos(polar_rad),
427
+ np.cos(azimuth_rad),
428
+ -np.sin(azimuth_rad) * np.sin(polar_rad)],
429
+ [np.sin(polar_rad),
430
+ 0.0,
431
+ np.cos(polar_rad)]])
432
+ # print('camera_R:', lo(camera_R).v)
433
+
434
+ # Multiply by corners in camera space to obtain corners in world space:
435
+ corn1 = [-1.0, np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0)]
436
+ corn2 = [-1.0, -np.tan(fov_rad / 2.0), np.tan(fov_rad / 2.0)]
437
+ corn3 = [-1.0, -np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0)]
438
+ corn4 = [-1.0, np.tan(fov_rad / 2.0), -np.tan(fov_rad / 2.0)]
439
+ corn1 = np.dot(camera_R, corn1)
440
+ corn2 = np.dot(camera_R, corn2)
441
+ corn3 = np.dot(camera_R, corn3)
442
+ corn4 = np.dot(camera_R, corn4)
443
+
444
+ # Now attach as offset to actual 3D camera position:
445
+ corn1 = np.array(corn1) / np.linalg.norm(corn1, ord=2)
446
+ corn_x1 = cam_x + corn1[0]
447
+ corn_y1 = cam_y + corn1[1]
448
+ corn_z1 = cam_z + corn1[2]
449
+ corn2 = np.array(corn2) / np.linalg.norm(corn2, ord=2)
450
+ corn_x2 = cam_x + corn2[0]
451
+ corn_y2 = cam_y + corn2[1]
452
+ corn_z2 = cam_z + corn2[2]
453
+ corn3 = np.array(corn3) / np.linalg.norm(corn3, ord=2)
454
+ corn_x3 = cam_x + corn3[0]
455
+ corn_y3 = cam_y + corn3[1]
456
+ corn_z3 = cam_z + corn3[2]
457
+ corn4 = np.array(corn4) / np.linalg.norm(corn4, ord=2)
458
+ corn_x4 = cam_x + corn4[0]
459
+ corn_y4 = cam_y + corn4[1]
460
+ corn_z4 = cam_z + corn4[2]
461
+
462
+ xs = [cam_x, corn_x1, corn_x2, corn_x3, corn_x4]
463
+ ys = [cam_y, corn_y1, corn_y2, corn_y3, corn_y4]
464
+ zs = [cam_z, corn_z1, corn_z2, corn_z3, corn_z4]
465
+
466
+ return np.array([xs, ys, zs]).T
467
 
468
 
469
  def run_demo(
470
+ device_idx=_GPU_INDEX,
471
+ ckpt='105000.ckpt',
472
+ config='configs/sd-objaverse-finetune-c_concat-256.yaml'):
473
+
474
+ print('sys.argv:', sys.argv)
475
+ if len(sys.argv) > 1:
476
+ print('old device_idx:', device_idx)
477
+ device_idx = int(sys.argv[1])
478
+ print('new device_idx:', device_idx)
479
 
480
  device = f'cuda:{device_idx}'
481
  config = OmegaConf.load(config)
482
+
483
+ # Instantiate all models beforehand for efficiency.
484
+ models = dict()
485
+ print('Instantiating LatentDiffusion...')
486
+ models['turncam'] = load_model_from_config(config, ckpt, device=device)
487
+ print('Instantiating Carvekit HiInterface...')
488
+ models['carvekit'] = create_carvekit_interface()
489
+ print('Instantiating StableDiffusionSafetyChecker...')
490
+ models['nsfw'] = StableDiffusionSafetyChecker.from_pretrained(
491
+ 'CompVis/stable-diffusion-safety-checker').to(device)
492
+ print('Instantiating AutoFeatureExtractor...')
493
+ models['clip_fe'] = AutoFeatureExtractor.from_pretrained(
494
+ 'CompVis/stable-diffusion-safety-checker')
495
+
496
+ # Reduce NSFW false positives.
497
+ # NOTE: At the time of writing, and for diffusers 0.12.1, the default parameters are:
498
+ # models['nsfw'].concept_embeds_weights:
499
+ # [0.1800, 0.1900, 0.2060, 0.2100, 0.1950, 0.1900, 0.1940, 0.1900, 0.1900, 0.2200, 0.1900,
500
+ # 0.1900, 0.1950, 0.1984, 0.2100, 0.2140, 0.2000].
501
+ # models['nsfw'].special_care_embeds_weights:
502
+ # [0.1950, 0.2000, 0.2200].
503
+ # We multiply all by some factor > 1 to make them less likely to be triggered.
504
+ models['nsfw'].concept_embeds_weights *= 1.07
505
+ models['nsfw'].special_care_embeds_weights *= 1.07
506
+
507
+ with open('instructions.md', 'r') as f:
508
+ article = f.read()
509
+
510
+ # Compose demo layout & data flow.
511
+ demo = gr.Blocks(title=_TITLE)
512
+
513
+ with demo:
514
+ gr.Markdown('# ' + _TITLE)
515
+ gr.Markdown(_DESCRIPTION)
516
+
517
+ with gr.Row():
518
+ with gr.Column(scale=0.9, variant='panel'):
519
+
520
+ image_block = gr.Image(type='pil', image_mode='RGBA',
521
+ label='Input image of single object')
522
+ preprocess_chk = gr.Checkbox(
523
+ True, label='Preprocess image automatically (remove background and recenter object)')
524
+ # info='If enabled, the uploaded image will be preprocessed to remove the background and recenter the object by cropping and/or padding as necessary. '
525
+ # 'If disabled, the image will be used as-is, *BUT* a fully transparent or white background is required.'),
526
+
527
+ gr.Markdown('*Try camera position presets:*')
528
+ with gr.Row():
529
+ left_btn = gr.Button('View from the Left', variant='primary')
530
+ above_btn = gr.Button('View from Above', variant='primary')
531
+ right_btn = gr.Button('View from the Right', variant='primary')
532
+ with gr.Row():
533
+ random_btn = gr.Button('Random Rotation', variant='primary')
534
+ below_btn = gr.Button('View from Below', variant='primary')
535
+ behind_btn = gr.Button('View from Behind', variant='primary')
536
+
537
+ gr.Markdown('*Control camera position manually:*')
538
+ polar_slider = gr.Slider(
539
+ -90, 90, value=0, step=5, label='Polar angle (vertical rotation in degrees)')
540
+ # info='Positive values move the camera down, while negative values move the camera up.')
541
+ azimuth_slider = gr.Slider(
542
+ -180, 180, value=0, step=5, label='Azimuth angle (horizontal rotation in degrees)')
543
+ # info='Positive values move the camera right, while negative values move the camera left.')
544
+ radius_slider = gr.Slider(
545
+ -0.5, 0.5, value=0.0, step=0.1, label='Zoom (relative distance from center)')
546
+ # info='Positive values move the camera further away, while negative values move the camera closer.')
547
+
548
+ samples_slider = gr.Slider(1, 8, value=4, step=1,
549
+ label='Number of samples to generate')
550
+
551
+ with gr.Accordion('Advanced options', open=False):
552
+ scale_slider = gr.Slider(0, 30, value=3, step=1,
553
+ label='Diffusion guidance scale')
554
+ steps_slider = gr.Slider(5, 200, value=75, step=5,
555
+ label='Number of diffusion inference steps')
556
+
557
+ with gr.Row():
558
+ vis_btn = gr.Button('Visualize Angles', variant='secondary')
559
+ run_btn = gr.Button('Run Generation', variant='primary')
560
+
561
+ desc_output = gr.Markdown('The results will appear on the right.', visible=_SHOW_DESC)
562
+
563
+ with gr.Column(scale=1.1, variant='panel'):
564
+
565
+ vis_output = gr.Plot(
566
+ label='Relationship between input (green) and output (blue) camera poses')
567
+
568
+ gen_output = gr.Gallery(label='Generated images from specified new viewpoint')
569
+ gen_output.style(grid=2)
570
+
571
+ preproc_output = gr.Image(type='pil', image_mode='RGB',
572
+ label='Preprocessed input image', visible=_SHOW_INTERMEDIATE)
573
+
574
+ gr.Markdown(article)
575
+
576
+ # NOTE: I am forced to update vis_output for these preset buttons,
577
+ # because otherwise the gradio plot always resets the plotly 3D viewpoint for some reason,
578
+ # which might confuse the user into thinking that the plot has been updated too.
579
+
580
+ # OLD 1:
581
+ # left_btn.click(fn=lambda: [0.0, -90.0], #, 0.0],
582
+ # inputs=[], outputs=[polar_slider, azimuth_slider]), #], radius_slider])
583
+ # above_btn.click(fn=lambda: [90.0, 0.0], #, 0.0],
584
+ # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
585
+ # right_btn.click(fn=lambda: [0.0, 90.0], #, 0.0],
586
+ # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
587
+ # random_btn.click(fn=lambda: [int(np.round(np.random.uniform(-60.0, 60.0))),
588
+ # int(np.round(np.random.uniform(-150.0, 150.0)))], #, 0.0],
589
+ # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
590
+ # below_btn.click(fn=lambda: [-90.0, 0.0], #, 0.0],
591
+ # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
592
+ # behind_btn.click(fn=lambda: [0.0, 180.0], #, 0.0],
593
+ # inputs=[], outputs=[polar_slider, azimuth_slider]), #, radius_slider])
594
+
595
+ # OLD 2:
596
+ # preset_text = ('You have selected a preset target camera view. '
597
+ # 'Now click Run Generation to update the results!')
598
+
599
+ # left_btn.click(fn=lambda: [0.0, -90.0, None, preset_text],
600
+ # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
601
+ # above_btn.click(fn=lambda: [90.0, 0.0, None, preset_text],
602
+ # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
603
+ # right_btn.click(fn=lambda: [0.0, 90.0, None, preset_text],
604
+ # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
605
+ # random_btn.click(fn=lambda: [int(np.round(np.random.uniform(-60.0, 60.0))),
606
+ # int(np.round(np.random.uniform(-150.0, 150.0))),
607
+ # None, preset_text],
608
+ # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
609
+ # below_btn.click(fn=lambda: [-90.0, 0.0, None, preset_text],
610
+ # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
611
+ # behind_btn.click(fn=lambda: [0.0, 180.0, None, preset_text],
612
+ # inputs=[], outputs=[polar_slider, azimuth_slider, vis_output, desc_output])
613
+
614
+ # OLD 3 (does not work at all):
615
+ # def a():
616
+ # polar_slider.value = 77.7
617
+ # polar_slider.postprocess(77.7)
618
+ # print('testa')
619
+ # left_btn.click(fn=a)
620
+
621
+ cam_vis = CameraVisualizer(vis_output)
622
+
623
+ vis_btn.click(fn=partial(main_run, models, device, cam_vis, 'vis'),
624
+ inputs=[polar_slider, azimuth_slider, radius_slider,
625
+ image_block, preprocess_chk],
626
+ outputs=[desc_output, vis_output, preproc_output])
627
+
628
+ run_btn.click(fn=partial(main_run, models, device, cam_vis, 'gen'),
629
+ inputs=[polar_slider, azimuth_slider, radius_slider,
630
+ image_block, preprocess_chk,
631
+ scale_slider, samples_slider, steps_slider],
632
+ outputs=[desc_output, vis_output, preproc_output, gen_output])
633
+
634
+ # NEW:
635
+ preset_inputs = [image_block, preprocess_chk,
636
+ scale_slider, samples_slider, steps_slider]
637
+ preset_outputs = [polar_slider, azimuth_slider, radius_slider,
638
+ desc_output, vis_output, preproc_output, gen_output]
639
+ left_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
640
+ 0.0, -90.0, 0.0),
641
+ inputs=preset_inputs, outputs=preset_outputs)
642
+ above_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
643
+ -90.0, 0.0, 0.0),
644
+ inputs=preset_inputs, outputs=preset_outputs)
645
+ right_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
646
+ 0.0, 90.0, 0.0),
647
+ inputs=preset_inputs, outputs=preset_outputs)
648
+ random_btn.click(fn=partial(main_run, models, device, cam_vis, 'rand_angles_gen',
649
+ -1.0, -1.0, -1.0),
650
+ inputs=preset_inputs, outputs=preset_outputs)
651
+ below_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
652
+ 90.0, 0.0, 0.0),
653
+ inputs=preset_inputs, outputs=preset_outputs)
654
+ behind_btn.click(fn=partial(main_run, models, device, cam_vis, 'angles_gen',
655
+ 0.0, 180.0, 0.0),
656
+ inputs=preset_inputs, outputs=preset_outputs)
657
+
658
  demo.launch(enable_queue=True, share=True)
659
 
660
 
661
  if __name__ == '__main__':
662
+
663
  fire.Fire(run_demo)
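For context on the camera controls wired up above: sample_model (see the app.py hunk header earlier in this diff) conditions the diffusion model on a 4-dimensional relative-pose vector built from the polar angle, the azimuth, and the radius offset, which is concatenated to the image embedding before cc_projection. Below is a minimal, self-contained sketch of that conditioning term; the helper name make_pose_conditioning is illustrative and not part of the repository.

import math
import torch

def make_pose_conditioning(polar_deg, azimuth_deg, radius_offset, n_samples=4):
    # Mirrors the T tensor assembled in sample_model: polar angle in radians,
    # azimuth encoded as (sin, cos), plus the relative radius offset.
    # (Helper name and standalone form are ours, for illustration only.)
    T = torch.tensor([
        math.radians(polar_deg),
        math.sin(math.radians(azimuth_deg)),
        math.cos(math.radians(azimuth_deg)),
        radius_offset,
    ])
    # One conditioning row per generated sample: shape (n_samples, 1, 4).
    return T[None, None, :].repeat(n_samples, 1, 1)

# Example: the 'View from the Left' preset corresponds to polar=0, azimuth=-90, radius=0.
print(make_pose_conditioning(0.0, -90.0, 0.0).shape)  # torch.Size([4, 1, 4])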
gradio_new.py DELETED
@@ -1,663 +0,0 @@
taming-transformers ADDED
@@ -0,0 +1 @@
1
+ Subproject commit 3ba01b241669f5ade541ce990f7650a3b8f65318
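A closing note on the preprocessing path in app.py above: when automatic background removal is disabled, the RGBA input is alpha-composited onto a white background rather than hard-thresholded on alpha (the commented-out "old method"). A small numpy sketch of that step, with an illustrative function name:

import numpy as np

def composite_on_white(rgba):
    # rgba: (H, W, 4) float array in [0, 1]; returns an (H, W, 3) array in [0, 1].
    # Same blend as preprocess_image: alpha * image + (1 - alpha) * white.
    alpha = rgba[:, :, 3:4]
    white = np.ones_like(rgba)
    out = alpha * rgba + (1.0 - alpha) * white
    return out[:, :, :3]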