imsuperkong committed
Commit
55ef7fd
1 Parent(s): 2abfd01

Upload app.py

Files changed (1): app.py (+456, -0)
app.py ADDED
@@ -0,0 +1,456 @@
import argparse
import os

import cv2
import numpy as np
import torch
import gradio as gr
from tqdm import tqdm
from diffusers import DDIMScheduler, DDPMScheduler

import sd.gradio_utils as gradio_utils
from sd.core import DDIMBackward, DDPM_forward

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

def slerp(R_target, rotation_speed):
    # Compute the angle of rotation from the rotation matrix.
    angle = np.arccos((np.trace(R_target) - 1) / 2)

    # Handle the case where the angle is very small (no significant rotation).
    if angle < 1e-6:
        return np.eye(3)

    # Scale the angle by rotation_speed.
    normalized_angle = angle * rotation_speed

    # Axis of rotation, recovered from the skew-symmetric part of R_target.
    axis = np.array([R_target[2, 1] - R_target[1, 2],
                     R_target[0, 2] - R_target[2, 0],
                     R_target[1, 0] - R_target[0, 1]])
    axis = axis / np.linalg.norm(axis)

    # Return the interpolated rotation matrix.
    return cv2.Rodrigues(axis * normalized_angle)[0]

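# Illustrative usage of slerp (example values are assumptions; slerp is a
# utility and is not called by the click handler in this file):
#
#   R_full = cv2.Rodrigues(np.array([0.0, 1.0, 0.0]) * 0.3)[0]  # 0.3 rad about Y
#   R_half = slerp(R_full, 0.5)                                 # ~0.15 rad about the same axis
#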
def compute_extrinsic_parameters(clicked_point, depth, intrinsic_matrix, rotation_speed, step_x=0, step_y=0, step_z=0):
    # Normalize the clicked point using the pinhole intrinsics.
    x, y = clicked_point
    x = int(x)
    y = int(y)
    x_normalized = (x - intrinsic_matrix[0, 2]) / intrinsic_matrix[0, 0]
    y_normalized = (y - intrinsic_matrix[1, 2]) / intrinsic_matrix[1, 1]

    # Depth at the clicked point (clamped to the image bounds to guard edge clicks).
    y = min(max(y, 0), depth.shape[0] - 1)
    x = min(max(x, 0), depth.shape[1] - 1)
    z = depth[y, x]

    # Direction vector in camera coordinates.
    direction_vector = np.array([x_normalized * z, y_normalized * z, z])

    # Calculate rotation angles that bring the clicked point to the center.
    angle_y = -np.arctan2(direction_vector[1], direction_vector[2])  # rotation about the Y-axis
    angle_x = np.arctan2(direction_vector[0], direction_vector[2])   # rotation about the X-axis

    # Apply the rotation speed.
    angle_y *= rotation_speed
    angle_x *= rotation_speed

    # Compute the rotation matrices.
    R_x = cv2.Rodrigues(np.array([1, 0, 0]) * angle_x)[0]
    R_y = cv2.Rodrigues(np.array([0, 1, 0]) * angle_y)[0]
    R = R_y @ R_x

    # Translation step (right-hand coordinates, Z pointing into the image).
    T = np.array([step_x, -step_y, -step_z])

    # Assemble the 4x4 extrinsic matrix.
    extrinsic_matrix = np.eye(4)
    extrinsic_matrix[:3, :3] = R
    extrinsic_matrix[:3, 3] = T

    return extrinsic_matrix

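# Illustrative call (assumed values): clicking the exact image center with zero
# rotation speed yields a pure translation of 5 units along the camera Z-axis:
#
#   E = compute_extrinsic_parameters((256, 256), depth, intrinsic, 0.0, step_z=5)
#   # E[:3, :3] is the identity rotation; E[:3, 3] == [0, 0, -5]
#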
@torch.no_grad()
def encode_imgs(imgs):
    # Map [0, 1] images to [-1, 1], encode with the VAE, and apply the SD latent scale.
    imgs = 2 * imgs - 1
    posterior = pipe.vae.encode(imgs).latent_dist
    latents = posterior.mean * 0.18215
    return latents


@torch.no_grad()
def decode_latents(latents):
    # Invert the latent scaling, decode with the VAE, and map back to [0, 1].
    latents = 1 / 0.18215 * latents
    imgs = pipe.vae.decode(latents).sample
    imgs = (imgs / 2 + 0.5).clamp(0, 1)
    return imgs

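# Sketch of the approximate round trip these two helpers form (assumes `pipe` is
# loaded and `x` is a [1, 3, 512, 512] tensor in [0, 1] on the right device/dtype):
#
#   x_rec = decode_latents(encode_imgs(x))   # x_rec ~ x up to VAE reconstruction error
#
# 0.18215 is Stable Diffusion's standard VAE latent scaling factor.
#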
@torch.no_grad()
def ddim_inversion(latent, cond, stop_t=1000, start_t=-1):
    # Configure the schedule first so `timesteps` reflects num_inference_steps,
    # then walk the DDIM update in reverse (x_0 -> x_T).
    pipe.scheduler.set_timesteps(num_inference_steps)
    timesteps = reversed(pipe.scheduler.timesteps)
    for i, t in enumerate(tqdm(timesteps)):
        if t >= stop_t:
            break
        if t <= start_t:
            continue
        cond_batch = cond.repeat(latent.shape[0], 1, 1)

        alpha_prod_t = pipe.scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = (
            pipe.scheduler.alphas_cumprod[timesteps[i - 1]]
            if i > 0 else pipe.scheduler.final_alpha_cumprod
        )

        mu = alpha_prod_t ** 0.5
        mu_prev = alpha_prod_t_prev ** 0.5
        sigma = (1 - alpha_prod_t) ** 0.5
        sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

        eps = pipe.unet(latent, t, encoder_hidden_states=cond_batch).sample

        # Invert one DDIM step: estimate x_0 at the previous (less noisy) level,
        # then re-noise it up to the current level t.
        pred_x0 = (latent - sigma_prev * eps) / mu_prev
        latent = mu * pred_x0 + sigma * eps

    return latent

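# For reference, one inversion step above in closed form (deterministic DDIM):
#
#   x0_hat = (x_prev - sqrt(1 - a_prev) * eps) / sqrt(a_prev)
#   x_t    = sqrt(a_t) * x0_hat + sqrt(1 - a_t) * eps
#
# where a_t and a_prev are the scheduler's cumulative alpha products; iterating
# this maps a clean latent back towards Gaussian noise.
#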
@torch.no_grad()
def get_text_embeds(prompt, negative_prompt='', batch_size=1):
    # Tokenize and encode the prompt and the (unconditional) negative prompt.
    text_input = pipe.tokenizer(prompt, padding='max_length', max_length=77, truncation=True, return_tensors='pt')
    text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]

    uncond_input = pipe.tokenizer(negative_prompt, padding='max_length', max_length=77, truncation=True, return_tensors='pt')
    uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]

    # Concatenate [uncond, cond] for classifier-free guidance.
    text_embeddings = torch.cat([uncond_embeddings] * batch_size + [text_embeddings] * batch_size).to(torch_dtype)
    return text_embeddings

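# The [uncond, cond] layout matches classifier-free guidance. Illustrative use
# (prompt is an assumption; the embedding width depends on the text encoder):
#
#   emb = get_text_embeds('a photo of a cat')   # shape [2, 77, dim]
#   uncond_emb, cond_emb = emb.chunk(2)         # first half is unconditional
#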
def save_video(frames, fps=10, out_path='output/output.mp4'):
    # Create the output directory before opening the writer.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    video_dims = (512, 512)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video = cv2.VideoWriter(out_path, fourcc, fps, video_dims)
    for frame in frames:
        video.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
    video.release()

def draw_prompt(prompt):
    return prompt


def to_image(tensor):
    # CHW tensor -> HWC uint8 image, min-max normalized to [0, 255].
    tensor = tensor.squeeze(0).permute(1, 2, 0)
    arr = tensor.detach().cpu().numpy()
    arr = (arr - arr.min()) / (arr.max() - arr.min())
    arr = arr * 255
    return arr.astype('uint8')


def add_points_to_image(image, points):
    image = gradio_utils.draw_handle_target_points(image, points, 5)
    return image

def on_click(state, seed, count, prompt, neg_prompt, speed_r, speed_x, speed_y, speed_z,
             t1, t2, t3, lr, guidance_weight, attn, threshold, early_stop, evt: gr.SelectData):
    # Map the t1/t2/t3 slider indices to scheduler timesteps.
    end_id = int(t1)
    start_id = int(t2)
    startstart_id = int(t3)
    timesteps = reversed(ddim_scheduler.timesteps)
    end_t = timesteps[end_id]
    start_t = timesteps[start_id]
    startstart_t = timesteps[startstart_id]
    attn = float(attn)
    cfg_norm = False
    cfg_decay = False
    guidance_loss_scale = float(guidance_weight)
    lr = float(lr)
    threshold = int(threshold)
    up_ft_indexes = 2
    early_stop = int(early_stop)
    generator = torch.Generator(device).manual_seed(int(seed))

    state['direction_offset'] = [int(evt.index[0]), int(evt.index[1])]
    cond = pipe._encode_prompt(prompt, device, 1, True, '')
    for _ in range(int(count)):
        image = state['img']
        img_tensor = torch.from_numpy(np.array(image) / 255.).to(device).to(torch_dtype).permute(2, 0, 1).unsqueeze(0)
        _, _, depth = pipe.midas_model(np.array(image))

        # Clicks near the image center translate only; other clicks also rotate
        # the camera towards the clicked direction.
        centered = is_centered(state['direction_offset'])
        if centered:
            extrinsic = compute_extrinsic_parameters(state['direction_offset'], depth, intrinsic, rotation_speed=0.0,
                                                     step_z=float(speed_z), step_x=float(speed_x), step_y=float(speed_y))
            state['centered'] = centered
        else:
            extrinsic = compute_extrinsic_parameters(state['direction_offset'], depth, intrinsic, rotation_speed=float(speed_r),
                                                     step_z=float(speed_z), step_x=float(speed_x), step_y=float(speed_y))

        # Invert the current frame to t1, warp the latent with the depth-based
        # camera motion, then invert further to t2 and add DDPM noise up to t3.
        this_latent = encode_imgs(img_tensor)
        this_ddim_inv_noise_end = ddim_inversion(this_latent, cond[1:], stop_t=end_t)
        this_ddim_inv_noise_start = ddim_inversion(this_latent, cond[1:], stop_t=startstart_t)

        wrapped_this_ddim_inv_noise_end = pipe.midas_model.wrap_img_tensor_w_fft_ext(
            this_ddim_inv_noise_end.to(torch_dtype),
            torch.from_numpy(depth).to(device).to(torch_dtype),
            intrinsic,
            extrinsic[:3, :3], extrinsic[:3, 3], threshold=threshold).to(torch_dtype)

        wrapped_this_ddim_inv_noise_start = ddim_inversion(wrapped_this_ddim_inv_noise_end, cond[1:], stop_t=start_t, start_t=end_t)
        wrapped_this_ddim_inv_noise_start = DDPM_forward(wrapped_this_ddim_inv_noise_start, t_start=start_t,
                                                         delta_t=(startstart_id - start_id) * 20,
                                                         ddpm_scheduler=ddpm_scheduler, generator=generator)

        # Denoise the [previous, warped] pair jointly with KV injection and
        # feature-correspondence guidance; keep the warped branch as the new frame.
        new_img = pipe.denoise_w_injection(
            prompt, generator=generator, num_inference_steps=num_inference_steps,
            latents=torch.cat([this_ddim_inv_noise_start, wrapped_this_ddim_inv_noise_start], dim=0), t_start=startstart_t,
            latent_mask=torch.ones_like(this_latent[0, 0, ...], device=device).unsqueeze(0),
            f=0, attn=attn, guidance_scale=7.5, negative_prompt=neg_prompt,
            guidance_loss_scale=guidance_loss_scale, early_stop=early_stop, up_ft_indexes=[up_ft_indexes],
            cfg_norm=cfg_norm, cfg_decay=cfg_decay, lr=lr,
            intrinsic=intrinsic, extrinsic=extrinsic, threshold=threshold, depth=depth,
        ).images[1]

        new_img = np.array(new_img).astype(np.uint8)
        state['img'] = new_img

        state['img_his'].append(new_img)
        depth = (depth - depth.min()) / (depth.max() - depth.min())  # normalize depth to [0, 1] for display
        state['depth_his'].append(depth)

    return new_img, depth, state['img_his'], state

def is_centered(clicked_point, image_dimensions=(512, 512), threshold=5):
    # True if the click lands within `threshold` pixels of the image center on both axes.
    image_center = [dim // 2 for dim in image_dimensions]
    return all(abs(clicked_point[i] - image_center[i]) <= threshold for i in range(2))

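# For example, is_centered((256, 258)) is True (within 5 px of the 512x512
# center), while is_centered((300, 256)) is False.
#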
def gen_img(prompt, neg_prompt, state, seed):
    generator = torch.Generator(device).manual_seed(int(seed))
    img = pipe(
        prompt, generator=generator, num_inference_steps=num_inference_steps, negative_prompt=neg_prompt,
    ).images[0]
    img_array = np.array(img)
    _, _, depth = pipe.midas_model(img_array)
    depth = (depth - depth.min()) / (depth.max() - depth.min())  # normalize depth to [0, 1] for display

    state['img_his'] = [img_array]
    state['depth_his'] = [depth]
    state['ori_img'] = img_array
    state['img'] = img_array
    return img_array, depth, [img_array], state

def on_undo(state):
    # Drop the latest frame, but always keep at least the first one.
    if len(state['img_his']) > 1:
        del state['img_his'][-1]
        del state['depth_his'][-1]
    image = state['img_his'][-1]
    depth = state['depth_his'][-1]
    state['img'] = image
    return image, depth, state['img_his'], state

def on_reset(state):
    image = state['img_his'][0]
    depth = state['depth_his'][0]
    state['img'] = image
    state['img_his'] = [image]
    state['depth_his'] = [depth]
    return image, depth, state['img_his'], state


def get_prompt(text):
    return text


def on_save(state, video_name):
    save_video(state['img_his'], fps=5, out_path=f'output/{video_name}.mp4')


def on_seed(seed):
    return int(seed)

def main(args):
    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # DreamDrone

            Official implementation of [DreamDrone](https://hyokong.github.io/publications/dreamdrone-page/).

            **TL;DR:** Navigate dreamscapes with a ***click*** – your chosen point guides the drone's flight in a thrilling visual journey.

            ## Tutorial

            1. Enter your prompt (and a negative prompt, if necessary) in the textbox, then click the `Generate first image` button.
            2. Adjust the camera's moving speed in the `Direction` panel and set hyperparameters in the `Hyper params` panel.
            3. Click on the generated image to make the camera fly towards the clicked direction.
            4. The generated images are displayed in the gallery at the bottom. You can view them by clicking on them in the gallery or by using the left/right arrow buttons.

            ## Hints

            - For convenience, you can set the number of images to generate per click.
            - Our system uses a right-hand coordinate system, with the Z-axis pointing into the image.
            - The rotation speed determines how quickly the camera turns towards the clicked direction (rotation only, no translation). Increase it for faster camera-pose changes.
            - The Speed X/Y/Z values control the camera's translation along the X, Y, and Z axes. Adjust them for different movement styles, similar to a camera arm.
            - $t_1$ is the timestep at which the latent code is warped.
            - Noise is added from $t_1$ to $t_3$. Between $t_1$ and $t_2$, the noise comes from the pretrained diffusion U-Net; from $t_2$ to $t_3$, random Gaussian noise is used.
            - The `Learning rate` and `Feature correspondence guidance` control the feature-correspondence guidance weight during denoising (from timestep $t_3$ to $0$).
            - The `KV injection` parameter adjusts the extent of key and value injection from the current frame to the next.

            > If you encounter any problems, please open an issue. Also, don't forget to star the [Official Github Repo](https://github.com/HyoKong/DreamDrone).

            ***Without further ado, welcome to DreamDrone – enjoy piloting your virtual drone through imaginative landscapes!***
            """,
        )
        img = np.zeros((512, 512, 3)).astype(np.uint8)
        depth_img = np.zeros((512, 512, 3)).astype(np.uint8)
        intrinsic_matrix = np.array([[1000, 0, 512 / 2],
                                     [0, 1000, 512 / 2],
                                     [0, 0, 1]])  # example intrinsic matrix
        extrinsic_matrix = np.array([[1.0, 0.0, 0.0, 0.0],
                                     [0.0, 1.0, 0.0, 0.0],
                                     [0.0, 0.0, 1.0, 0.0]],
                                    dtype=np.float32)
        direction_offset = (255, 255)
        state = gr.State({
            'ori_img': img,
            'img': None,
            'centered': False,
            'img_his': [],
            'depth_his': [],
            'intrinsic': intrinsic_matrix,
            'extrinsic': extrinsic_matrix,
            'direction_offset': direction_offset,
        })

        with gr.Row():
            with gr.Column(scale=0.2):
                with gr.Accordion("Direction"):
                    speed_r = gr.Number(value=0.1, label='Rotation Speed', step=0.01, minimum=0, maximum=1)
                    speed_x = gr.Number(value=0, label='Speed X-axis', step=1, minimum=-10, maximum=20.0)
                    speed_y = gr.Number(value=0, label='Speed Y-axis', step=1, minimum=-10, maximum=20.0)
                    speed_z = gr.Number(value=5, label='Speed Z-axis', step=1, minimum=-10, maximum=20.0)
                with gr.Accordion('Hyper params'):
                    with gr.Row():
                        count = gr.Number(value=5, label='Num. of generated images', step=1, minimum=1, maximum=10, precision=0)
                        seed = gr.Number(value=19491000, label='Seed', precision=0)
                    t1 = gr.Slider(1, 49, 2, step=1, label='t1')
                    t2 = gr.Slider(1, 49, 12, step=1, label='t2')
                    t3 = gr.Slider(1, 49, 27, step=1, label='t3')
                    lr = gr.Slider(0, 500, 300, step=1, label='Learning rate')
                    guidance_weight = gr.Slider(0, 10, 0.1, step=0.1, label='Feature correspondence guidance')
                    attn = gr.Slider(0, 1, 0.5, step=0.1, label='KV injection')
                    threshold = gr.Slider(0, 31, 20, step=1, label='Threshold of low-pass filter')
                    early_stop = gr.Slider(0, 50, 48, step=1, label='Early stop timestep for feature-correspondence guidance')
                    video_name = gr.Textbox(
                        label="Saved video name", show_label=True, max_lines=1, placeholder='saved video name', value='output',
                    )
            with gr.Column():
                with gr.Box():
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        text = gr.Textbox(
                            label="Enter your prompt", show_label=False, max_lines=1, placeholder='Enter your prompt',
                            value='Backyards of Old Houses in Antwerp in the Snow, van Gogh',
                        ).style(
                            border=(True, False, True, True),
                            rounded=(True, False, False, True),
                            container=False,
                        )
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        with gr.Column(scale=0.8):
                            neg_text = gr.Textbox(
                                label="Enter your negative prompt", show_label=False, max_lines=1, value='',
                                placeholder='Enter your negative prompt',
                            ).style(
                                border=(True, False, True, True),
                                rounded=(True, False, False, True),
                                container=False,
                            )
                        with gr.Column(scale=0.2):
                            gen_btn = gr.Button("Generate first image").style(
                                margin=False,
                                rounded=(False, True, True, False),
                            )

                with gr.Box():
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        with gr.Column():
                            with gr.Tab('Current view'):
                                image = gr.Image(img).style(height=600, width=600)
                        with gr.Column():
                            with gr.Tab('Depth'):
                                depth_image = gr.Image(depth_img).style(height=600, width=600)
                with gr.Row():
                    with gr.Column(min_width=100):
                        reset_btn = gr.Button('Clear All')
                    with gr.Column(min_width=100):
                        undo_btn = gr.Button('Undo Last')
                    with gr.Column(min_width=100):
                        save_btn = gr.Button('Save Video')
                with gr.Row():
                    with gr.Tab('Generated image gallery'):
                        gallery = gr.Gallery(
                            label='Generated images', show_label=False, elem_id='gallery', preview=True, rows=1, height=368,
                        )

        image.select(on_click,
                     [state, seed, count, text, neg_text, speed_r, speed_x, speed_y, speed_z,
                      t1, t2, t3, lr, guidance_weight, attn, threshold, early_stop],
                     [image, depth_image, gallery, state])
        text.submit(get_prompt, inputs=[text], outputs=[text])
        neg_text.submit(get_prompt, inputs=[neg_text], outputs=[neg_text])
        gen_btn.click(gen_img, inputs=[text, neg_text, state, seed], outputs=[image, depth_image, gallery, state])
        reset_btn.click(on_reset, inputs=[state], outputs=[image, depth_image, gallery, state])
        undo_btn.click(on_undo, inputs=[state], outputs=[image, depth_image, gallery, state])
        save_btn.click(on_save, inputs=[state, video_name], outputs=[])

    global num_inference_steps
    global pipe
    global intrinsic
    global ddim_scheduler
    global ddpm_scheduler
    global device
    global model_id
    global torch_dtype

    num_inference_steps = 50

    device = args.device
    model_id = args.model_id
    ddim_scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    ddpm_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
    torch_dtype = torch.float16 if 'cuda' in str(device) else torch.float32

    pipe = DDIMBackward.from_pretrained(
        model_id, scheduler=ddim_scheduler, torch_dtype=torch_dtype,
        cache_dir='.', device=str(device), model_id=model_id, depth_model=args.depth_model,
    ).to(str(device))

    if 'cuda' in str(device):
        pipe.enable_attention_slicing()
        pipe.enable_xformers_memory_efficient_attention()

    intrinsic = np.array([[1000, 0, 256],
                          [0, 1000., 256],
                          [0, 0, 1]])  # example intrinsic matrix
    return demo

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--model_id', default='stabilityai/stable-diffusion-2-1-base')
    parser.add_argument('--depth_model', default='dpt_beit_large_512', choices=['dpt_beit_large_512', 'dpt_swin2_large_384'])
    parser.add_argument('--share', action='store_true')
    parser.add_argument('-p', '--port', type=int, default=None)
    parser.add_argument('--ip', default=None)
    args = parser.parse_args()
    demo = main(args)
    print('Successfully loaded, starting gradio demo')
    demo.queue(concurrency_count=1, max_size=20).launch(share=args.share, server_name=args.ip, server_port=args.port)