customdiffusion360 committed
Commit 8eb5f81
Parent: 43b6675

add instructions, do not load sdxl on original space

Files changed (1): app.py (+48, -19)
app.py CHANGED
@@ -28,7 +28,7 @@ def transform_mesh(mesh, transform, scale=1.0):
     return mesh
 
 
-def get_input_pose_fig():
+def get_input_pose_fig(category=None):
     global curr_camera_dict
     global obj_filename
     global plane_trans
@@ -44,6 +44,11 @@ def get_input_pose_fig():
     ### plane
     rotate_x = RotateAxisAngle(angle=90.0, axis='X', device=device)
     plane = transform_mesh(plane, rotate_x)
+
+    if category == "teddybear":
+        rotate_teddy = RotateAxisAngle(angle=15.0, axis='X', device=device)
+        plane = transform_mesh(plane, rotate_teddy)
+
     translate_y = Translate(0, plane_trans * mesh_scale, 0, device=device)
     plane = transform_mesh(plane, translate_y)
 
@@ -171,7 +176,15 @@ def select_and_load_model(category, category_single_id):
 
     print("!!! model loaded")
 
-    input_prompt = f"photo of a <new1> {category}"
+    if category == "car":
+        input_prompt = "A <new1> car parked by a snowy mountain range"
+    elif category == "chair":
+        input_prompt = "A <new1> chair in a garden surrounded by flowers"
+    elif category == "motorcycle":
+        input_prompt = "A <new1> motorcycle beside a calm lake"
+    elif category == "teddybear":
+        input_prompt = "A <new1> teddy bear on the sand at the beach"
+
     return "### Model loaded!", input_prompt
 
 
@@ -184,9 +197,15 @@ global base_model
 BASE_CONFIG = "custom-diffusion360/configs/train_co3d_concept.yaml"
 BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"
 
-start_time = time.time()
-base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
-print(f"Time taken to load base model: {time.time() - start_time:.2f}s")
+base_model = None
+
+ORIGINAL_SPACE_ID = "customdiffusion360/customdiffusion360"
+SPACE_ID = os.getenv("SPACE_ID")
+
+if SPACE_ID != ORIGINAL_SPACE_ID:
+    start_time = time.time()
+    base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
+    print(f"Time taken to load base model: {time.time() - start_time:.2f}s")
 
 global curr_camera_dict
 curr_camera_dict = {
@@ -280,7 +299,7 @@ def update_category_single_id(category):
             "scene.aspectratio": {"x": 1.5786, "y": 1.5786, "z": 1.5786},
             "scene.aspectmode": "manual"
         }
-        plane_trans = 0.16
+        plane_trans = 0.2
 
     elif category == "teddybear":
         choices = ["31"]
@@ -299,7 +318,7 @@ def update_category_single_id(category):
             "scene.aspectratio": {"x": 1.8052, "y": 1.8052, "z": 1.8052},
             "scene.aspectmode": "manual",
         }
-        plane_trans = 0.23
+        plane_trans = 0.3
 
     obj_filename = f"assets/{category}{choices[0]}_mesh_centered_flipped.obj"
     prev_camera_dict = copy.deepcopy(curr_camera_dict)
@@ -310,13 +329,6 @@ head = """
 <script src="https://cdn.plot.ly/plotly-2.30.0.min.js" charset="utf-8"></script>
 """
 
-ORIGINAL_SPACE_ID = 'customdiffusion360'
-SPACE_ID = os.getenv('SPACE_ID')
-
-SHARED_UI_WARNING = f'''## Attention - the demo requires at least 40GB VRAM for inference. Please clone this repository to run on your own machine.
-<center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
-'''
-
 with gr.Blocks(head=head,
                css="style.css",
                js=scripts,
@@ -339,14 +351,21 @@ with gr.Blocks(head=head,
             <img src='https://img.shields.io/badge/Github-%23121011.svg'>
             </a>
             </div>
+            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+            <p>
+            This is a demo for <a href='https://github.com/customdiffusion360/custom-diffusion360'>Custom Diffusion 360</a>.
+            Please duplicate this space and upgrade the GPU to A10G Large in Settings to run the demo.
+            </p>
+            </div>
+            <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
+            <a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/customdiffusion360/customdiffusion360?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>
+            </div>
             <hr></hr>
             """,
             visible=True
         )
 
-    if SPACE_ID == ORIGINAL_SPACE_ID:
-        gr.Markdown(SHARED_UI_WARNING)
-
+
     with gr.Row():
         with gr.Column(min_width=150):
             gr.Markdown("## 1. SELECT CUSTOMIZED MODEL")
@@ -375,7 +394,7 @@ with gr.Blocks(head=head,
         ## TODO: track init_camera_dict and with js?
 
         ### visible elements
-        input_prompt = gr.Textbox(value="photo of a <new1> car", label="Prompt", interactive=True)
+        input_prompt = gr.Textbox(value="A <new1> car parked by a snowy mountain range", label="Prompt", interactive=True)
         scale_im = gr.Slider(value=3.5, label="Image guidance scale", minimum=0, maximum=20.0, step=0.1)
         scale = gr.Slider(value=7.5, label="Text guidance scale", minimum=0, maximum=20.0, step=0.1)
         steps = gr.Slider(value=10, label="Inference steps", minimum=1, maximum=50, step=1)
@@ -389,8 +408,18 @@ with gr.Blocks(head=head,
             gr.Markdown("## 3. OUR OUTPUT")
             result = gr.Image(show_label=False, show_download_button=True, width=512, height=512, elem_id="result")
 
+    gr.Markdown("### Camera Pose Controls:")
+    gr.Markdown("* Orbital rotation: Left-click and drag.")
+    gr.Markdown("* Zoom: Mouse wheel scroll.")
+    gr.Markdown("* Pan (translate the camera): Right-click and drag.")
+    gr.Markdown("* Tilt camera: Tilt mouse wheel left/right.")
+    gr.Markdown("* Reset to initial camera pose: Hover over the top right corner of the plot and click the camera icon.")
+    gr.Markdown("### Note:")
+    gr.Markdown("The models only work within a range of elevation angles and distances near the initial camera pose.")
+
+
     load_model_btn.click(select_and_load_model, [category, category_single_id], [load_model_status, input_prompt])
-    load_model_btn.click(get_input_pose_fig, [], [map])
+    load_model_btn.click(get_input_pose_fig, [category], [map])
 
     update_pose_btn.click(update_curr_camera_dict, [input_pose], [input_pose],)  # js=send_js_camera_to_gradio)
     # check_pose_btn.click(check_curr_camera_dict, [], [input_pose])
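
Note on the "do not load sdxl on original space" half of the commit: `base_model` now starts as `None`, and SDXL is loaded only when `SPACE_ID` differs from the original Space's id, so the shared demo (which cannot meet the ~40GB VRAM requirement the removed `SHARED_UI_WARNING` described) no longer pays the load cost. A minimal sketch of the gating pattern; `load_base_model` is stubbed here because the real loader lives in the repo:

```python
import os
import time

ORIGINAL_SPACE_ID = "customdiffusion360/customdiffusion360"
SPACE_ID = os.getenv("SPACE_ID")  # Spaces sets this to "owner/name"; None locally


def load_base_model(config: str, ckpt: str, verbose: bool = False):
    # Stand-in for app.py's real loader, which builds the SDXL pipeline
    # from the config YAML and the .safetensors checkpoint.
    return object()


base_model = None
if SPACE_ID != ORIGINAL_SPACE_ID:
    # Local runs and duplicated Spaces load SDXL eagerly; the original
    # shared Space leaves base_model as None and relies on the new
    # "Duplicate Space" banner instead.
    start_time = time.time()
    base_model = load_base_model(
        "custom-diffusion360/configs/train_co3d_concept.yaml",
        ckpt="pretrained-models/sd_xl_base_1.0.safetensors",
        verbose=False,
    )
    print(f"Time taken to load base model: {time.time() - start_time:.2f}s")
```

Because duplicated Spaces get their own `owner/name` slug in the `SPACE_ID` environment variable, they automatically take the loading branch with no extra configuration.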
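The event-wiring change at the bottom of the diff (`[]` to `[category]`) is what feeds the dropdown's current value into the new `category` parameter of `get_input_pose_fig`. A self-contained sketch of that Gradio pattern; the `out` Textbox is a stand-in for the `map` plot component in app.py:

```python
import gradio as gr


def get_input_pose_fig(category=None):
    # In app.py this rebuilds the Plotly camera-pose figure (including the
    # extra 15-degree plane tilt for "teddybear"); here we just echo the
    # value to show how the dropdown reaches the handler.
    return f"pose figure for category: {category}"


with gr.Blocks() as demo:
    category = gr.Dropdown(
        ["car", "chair", "motorcycle", "teddybear"], value="car", label="Category"
    )
    load_model_btn = gr.Button("Load selected model")
    out = gr.Textbox(label="Result")
    # Components listed in `inputs` are read at click time and passed
    # positionally to the handler, which is why the commit changes [] to
    # [category].
    load_model_btn.click(get_input_pose_fig, inputs=[category], outputs=[out])

demo.launch()
```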
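For the teddybear-specific plane tilt, `RotateAxisAngle` and `Translate` come from `pytorch3d.transforms`, and `transform_mesh` in app.py applies each transform to the plane's vertices. A standalone sketch of the same transform stack, assuming a placeholder `mesh_scale` of 1.0 (app.py derives it per category) and the `plane_trans = 0.3` value this commit sets for teddybear:

```python
import torch
from pytorch3d.transforms import RotateAxisAngle, Translate

device = "cpu"
mesh_scale = 1.0   # placeholder; app.py computes this per category
plane_trans = 0.3  # teddybear value introduced by this commit

# Same sequence the patched get_input_pose_fig applies to the ground plane:
# stand the plane up with a 90-degree X-rotation, tilt it a further
# 15 degrees for the teddybear mesh, then lift it along Y.
rotate_x = RotateAxisAngle(angle=90.0, axis="X", device=device)
rotate_teddy = RotateAxisAngle(angle=15.0, axis="X", device=device)
translate_y = Translate(0, plane_trans * mesh_scale, 0, device=device)

# compose() applies the transforms in order when transforming points.
tform = rotate_x.compose(rotate_teddy, translate_y)
corner = torch.tensor([[1.0, 0.0, 1.0]])  # one plane vertex, shape (P, 3)
print(tform.transform_points(corner))     # its position after the full stack
```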