toshas committed on
Commit 3f0925c
1 Parent(s): 4190f7f

add feedback collection functionality through HF OAuth

Files changed (6):
  1. README.md +2 -0
  2. app.py +292 -48
  3. extrude.py +3 -1
  4. flagging.py +387 -0
  5. marigold_depth_estimation_lcm.py +3 -1
  6. requirements.txt +5 -5
README.md CHANGED
@@ -10,6 +10,8 @@ pinned: true
 license: cc-by-sa-4.0
 models:
 - prs-eth/marigold-lcm-v1-0
+hf_oauth: true
+hf_oauth_expiration_minutes: 43200
 ---

 This is a demo of Marigold-LCM, the state-of-the-art depth estimator for images in the wild.
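Note: the two new front-matter keys enable Sign-in-with-Hugging-Face for this Space: hf_oauth: true switches the OAuth flow on, and hf_oauth_expiration_minutes: 43200 keeps the issued token valid for 43200 minutes, i.e. 30 days. A minimal sketch of how a Gradio app can consume the resulting login state; the handler and component names here are illustrative, not part of this commit:

    from __future__ import annotations

    import gradio as gr

    with gr.Blocks() as demo:
        # LoginButton only completes the OAuth flow when the Space sets hf_oauth: true
        login_btn = gr.LoginButton(size="sm")
        status = gr.Markdown()

        # Gradio injects the OAuth profile (or None when logged out) into any
        # handler that annotates a parameter with gr.OAuthProfile | None.
        def greet(profile: gr.OAuthProfile | None):
            if profile is None:
                return "Logged out"
            return f"Logged in as {profile.username}"

        demo.load(greet, inputs=None, outputs=status)

    demo.launch()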
app.py CHANGED
@@ -16,18 +16,19 @@
 # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
 # More information about the method can be found at https://marigoldmonodepth.github.io
 # --------------------------------------------------------------------------
-
+from __future__ import annotations

 import functools
 import os
 import tempfile
+import warnings
 import zipfile
 from io import BytesIO

-import spaces
 import gradio as gr
 import imageio as imageio
 import numpy as np
+import spaces
 import torch as torch
 from PIL import Image
 from gradio_imageslider import ImageSlider
@@ -35,8 +36,13 @@ from huggingface_hub import login
 from tqdm import tqdm

 from extrude import extrude_depth_3d
+from flagging import FlagMethod, HuggingFaceDatasetSaver
 from marigold_depth_estimation_lcm import MarigoldDepthConsistencyPipeline

+warnings.filterwarnings(
+    "ignore", message=".*LoginButton created outside of a Blocks context.*"
+)
+
 default_seed = 2024

 default_image_denoise_steps = 4
@@ -64,6 +70,16 @@ default_bas_frame_thickness = 5
 default_bas_frame_near = 1
 default_bas_frame_far = 1

+default_share_always_show_hf_logout_btn = True
+default_share_always_show_accordion = False
+
+
+def process_image_check(path_input):
+    if path_input is None:
+        raise gr.Error(
+            "Missing image in the first pane: upload a file or use one from the gallery below."
+        )
+

 def process_image(
     pipe,
@@ -72,6 +88,14 @@ def process_image(
     ensemble_size=default_image_ensemble_size,
     processing_res=default_image_processing_res,
 ):
+    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
+    print(f"Processing image {name_base}{name_ext}")
+
+    path_output_dir = tempfile.mkdtemp()
+    path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
+    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")
+    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")
+
     input_image = Image.open(path_input)

     pipe_out = pipe(
@@ -88,13 +112,6 @@
     depth_colored = pipe_out.depth_colored
     depth_16bit = (depth_pred * 65535.0).astype(np.uint16)

-    path_output_dir = tempfile.mkdtemp()
-
-    name_base = os.path.splitext(os.path.basename(path_input))[0]
-    path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
-    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")
-    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")
-
     np.save(path_out_fp32, depth_pred)
     Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16")
     depth_colored.save(path_out_vis)
@@ -116,9 +133,15 @@ def process_video(
     out_max_frames=default_video_out_max_frames,
     progress=gr.Progress(),
 ):
-    path_output_dir = tempfile.mkdtemp()
+    if path_input is None:
+        raise gr.Error(
+            "Missing video in the first pane: upload a file or use one from the gallery below."
+        )

-    name_base = os.path.splitext(os.path.basename(path_input))[0]
+    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
+    print(f"Processing video {name_base}{name_ext}")
+
+    path_output_dir = tempfile.mkdtemp()
     path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4")
     path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip")

@@ -218,12 +241,18 @@ def process_bas(
     frame_near=default_bas_frame_near,
     frame_far=default_bas_frame_far,
 ):
+    if path_input is None:
+        raise gr.Error(
+            "Missing image in the first pane: upload a file or use one from the gallery below."
+        )
+
     if plane_near >= plane_far:
         raise gr.Error("NEAR plane must have a value smaller than the FAR plane")

-    path_output_dir = tempfile.mkdtemp()
-
     name_base, name_ext = os.path.splitext(os.path.basename(path_input))
+    print(f"Processing bas-relief {name_base}{name_ext}")
+
+    path_output_dir = tempfile.mkdtemp()

     input_image = Image.open(path_input)

@@ -267,9 +296,11 @@
     path_glb, path_stl = extrude_depth_3d(
         image_rgb_new,
         image_depth_new,
-        output_model_scale=size_longest_cm * 10
-        if output_model_scale is None
-        else output_model_scale,
+        output_model_scale=(
+            size_longest_cm * 10
+            if output_model_scale is None
+            else output_model_scale
+        ),
         filter_size=filter_size,
         coef_near=plane_near,
         coef_far=plane_far,
@@ -288,17 +319,22 @@
         256, filter_size, vertex_colors=False, scene_lights=True, output_model_scale=1
     )
     path_files_glb, path_files_stl = _process_3d(
-        size_longest_px, filter_size, vertex_colors=True, scene_lights=False, prepare_for_3d_printing=True
+        size_longest_px,
+        filter_size,
+        vertex_colors=True,
+        scene_lights=False,
+        prepare_for_3d_printing=True,
     )

     return path_viewer_glb, [path_files_glb, path_files_stl]


-def run_demo_server(pipe):
+def run_demo_server(pipe, hf_writer=None):
     process_pipe_image = spaces.GPU(functools.partial(process_image, pipe))
-    process_pipe_video = spaces.GPU(functools.partial(process_video, pipe), duration=120)
+    process_pipe_video = spaces.GPU(
+        functools.partial(process_video, pipe), duration=120
+    )
     process_pipe_bas = spaces.GPU(functools.partial(process_bas, pipe))
-    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

     gradio_theme = gr.themes.Default()

@@ -332,6 +368,9 @@ def run_demo_server(pipe):
             text-align: center;
             display: block;
         }
+        .md_feedback li {
+            margin-bottom: 0px !important;
+        }
         """,
         head="""
         <script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
@@ -343,35 +382,70 @@
         </script>
         """,
     ) as demo:
+        if hf_writer is not None:
+            print("Creating login button")
+            share_login_btn = gr.LoginButton(size="sm", scale=1, render=False)
+            print("Created login button")
+            share_login_btn.activate()
+            print("Activated login button")
+
         gr.Markdown(
             """
             # Marigold-LCM Depth Estimation
             <p align="center">
-            <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer"
+               style="display: inline-block;">
                 <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
             </a>
-            <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer"
+               style="display: inline-block;">
                 <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
             </a>
-            <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
-                <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
+            <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer"
+               style="display: inline-block;">
+                <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C"
+                     alt="badge-github-stars">
            </a>
-            <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
+            <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer"
+               style="display: inline-block;">
                 <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
             </a>
             </p>
             <p align="justify">
-            Marigold-LCM is the fast version of Marigold, the state-of-the-art depth estimator for images in the wild.
-            It combines the power of the original Marigold 10-step estimator and the Latent Consistency Models, delivering high-quality results in as little as <b>one step</b>.
-            We provide three functions in this demo: Image, Video, and Bas-relief 3D processing — <b>see the tabs below</b>.
-            Upload your content into the <b>first</b> pane, or click any of the <b>examples</b> below.
-            Wait a second (for images and 3D) or a minute (for videos), and interact with the result in the <b>second</b> pane.
-            To avoid queuing, fork the demo into your profile.
-            <a href="https://huggingface.co/spaces/prs-eth/marigold">The original Marigold demo is also available</a>.
+            Marigold-LCM is the fast version of Marigold, the state-of-the-art depth estimator for images in the
+            wild. It combines the power of the original Marigold 10-step estimator and the Latent Consistency
+            Models, delivering high-quality results in as little as <b>one step</b>. We provide three functions
+            in this demo: Image, Video, and Bas-relief 3D processing — <b>see the tabs below</b>. Upload your
+            content into the <b>first</b> pane, or click any of the <b>examples</b> below. Wait a second (for
+            images and 3D) or a minute (for videos), and interact with the result in the <b>second</b> pane. To
+            avoid queuing, fork the demo into your profile.
+            <a href="https://huggingface.co/spaces/prs-eth/marigold">
+                The original Marigold demo is also available
+            </a>.
             </p>
             """
         )

+        def get_share_instructions(is_full):
+            out = (
+                "### Help us improve Marigold! If the output is not what you expected, "
+                "you can help us by sharing it with us privately.\n"
+            )
+            if is_full:
+                out += (
+                    "1. Sign into your Hugging Face account using the button below.\n"
+                    "1. Signing in may reset the demo and results; in that case, process the image again.\n"
+                )
+            out += "1. Review and agree to the terms of usage and enter an optional message to us.\n"
+            out += "1. Click the 'Share' button to submit the image to us privately.\n"
+            return out
+
+        def get_share_conditioned_on_login(profile: gr.OAuthProfile | None):
+            state_logged_out = profile is None
+            return get_share_instructions(is_full=state_logged_out), gr.Button(
+                visible=(state_logged_out or default_share_always_show_hf_logout_btn)
+            )
+
         with gr.Tabs(elem_classes=["tabs"]):
             with gr.Tab("Image"):
                 with gr.Row():
@@ -423,6 +497,42 @@
                         elem_id="download",
                         interactive=False,
                    )
+
+                    if hf_writer is not None:
+                        with gr.Accordion(
+                            "Feedback",
+                            open=False,
+                            visible=default_share_always_show_accordion,
+                        ) as share_box:
+                            share_instructions = gr.Markdown(
+                                get_share_instructions(is_full=True),
+                                elem_classes="md_feedback",
+                            )
+                            share_transfer_of_rights = gr.Checkbox(
+                                label="(Optional) I own or hold necessary rights to the submitted image. By "
+                                "checking this box, I grant an irrevocable, non-exclusive, transferable, "
+                                "royalty-free, worldwide license to use the uploaded image, including for "
+                                "publishing, reproducing, and model training. [transfer_of_rights]",
+                                scale=1,
+                            )
+                            share_content_is_legal = gr.Checkbox(
+                                label="By checking this box, I acknowledge that my uploaded content is legal and "
+                                "safe, and that I am solely responsible for ensuring it complies with all "
+                                "applicable laws and regulations. Additionally, I am aware that my Hugging Face "
+                                "username is collected. [content_is_legal]",
+                                scale=1,
+                            )
+                            share_reason = gr.Textbox(
+                                label="(Optional) Reason for feedback",
+                                max_lines=1,
+                                interactive=True,
+                            )
+                            with gr.Row():
+                                share_login_btn.render()
+                                share_share_btn = gr.Button(
+                                    "Share", variant="stop", scale=1
+                                )
+
                 gr.Examples(
                     fn=process_pipe_image,
                     examples=[
@@ -502,7 +612,8 @@
                     """
                     <p align="justify">
                     This part of the demo uses Marigold-LCM to create a bas-relief model.
-                    The models are watertight, with correct normals, and exported in the STL format, which makes them <b>3D-printable</b>.
+                    The models are watertight, with correct normals, and exported in the STL format, which makes
+                    them <b>3D-printable</b>.
                     </p>
                     """,
                 )
@@ -513,7 +624,9 @@
                         type="filepath",
                     )
                     with gr.Row():
-                        bas_submit_btn = gr.Button(value="Create 3D", variant="primary")
+                        bas_submit_btn = gr.Button(
+                            value="Create 3D", variant="primary"
+                        )
                         bas_reset_btn = gr.Button(value="Reset")
                     with gr.Accordion("3D printing demo: Main options", open=True):
                         bas_plane_near = gr.Slider(
@@ -537,7 +650,9 @@
                             step=1,
                             value=default_bas_embossing,
                         )
-                    with gr.Accordion("3D printing demo: Advanced options", open=False):
+                    with gr.Accordion(
+                        "3D printing demo: Advanced options", open=False
+                    ):
                         bas_denoise_steps = gr.Slider(
                             label="Number of denoising steps",
                             minimum=1,
@@ -682,17 +797,66 @@
                     cache_examples=True,
                 )

-        image_submit_btn.click(
-            fn=process_pipe_image,
-            inputs=[
-                image_input,
-                image_denoise_steps,
-                image_ensemble_size,
-                image_processing_res,
-            ],
-            outputs=[image_output_slider, image_output_files],
-            concurrency_limit=1,
-        )
+        ### Image tab
+
+        if hf_writer is not None:
+            image_submit_btn.click(
+                fn=process_image_check,
+                inputs=image_input,
+                outputs=None,
+                preprocess=False,
+                queue=False,
+            ).success(
+                get_share_conditioned_on_login,
+                None,
+                [share_instructions, share_login_btn],
+                queue=False,
+            ).then(
+                lambda: (
+                    gr.Button(value="Share", interactive=True),
+                    gr.Accordion(visible=True),
+                    False,
+                    False,
+                    "",
+                ),
+                None,
+                [
+                    share_share_btn,
+                    share_box,
+                    share_transfer_of_rights,
+                    share_content_is_legal,
+                    share_reason,
+                ],
+                queue=False,
+            ).then(
+                fn=process_pipe_image,
+                inputs=[
+                    image_input,
+                    image_denoise_steps,
+                    image_ensemble_size,
+                    image_processing_res,
+                ],
+                outputs=[image_output_slider, image_output_files],
+                concurrency_limit=1,
+            )
+        else:
+            image_submit_btn.click(
+                fn=process_image_check,
+                inputs=image_input,
+                outputs=None,
+                preprocess=False,
+                queue=False,
+            ).success(
+                fn=process_pipe_image,
+                inputs=[
+                    image_input,
+                    image_denoise_steps,
+                    image_ensemble_size,
+                    image_processing_res,
+                ],
+                outputs=[image_output_slider, image_output_files],
+                concurrency_limit=1,
+            )

         image_reset_btn.click(
             fn=lambda: (
@@ -712,9 +876,73 @@
                 image_denoise_steps,
                 image_processing_res,
             ],
-            concurrency_limit=1,
+            queue=False,
         )

+        if hf_writer is not None:
+            image_reset_btn.click(
+                fn=lambda: (
+                    gr.Button(value="Share", interactive=True),
+                    gr.Accordion(visible=default_share_always_show_accordion),
+                ),
+                inputs=[],
+                outputs=[
+                    share_share_btn,
+                    share_box,
+                ],
+                queue=False,
+            )
+
+        ### Share functionality
+
+        if hf_writer is not None:
+            share_components = [
+                image_input,
+                image_denoise_steps,
+                image_ensemble_size,
+                image_processing_res,
+                image_output_slider,
+                share_content_is_legal,
+                share_transfer_of_rights,
+                share_reason,
+            ]
+
+            hf_writer.setup(share_components, "shared_data")
+            share_callback = FlagMethod(hf_writer, "Share", "", visual_feedback=True)
+
+            def share_precheck(
+                hf_content_is_legal,
+                image_output_slider,
+                profile: gr.OAuthProfile | None,
+            ):
+                if profile is None:
+                    raise gr.Error(
+                        "Log into the Space with your Hugging Face account first."
+                    )
+                if image_output_slider is None or image_output_slider[0] is None:
+                    raise gr.Error("No output detected; process the image first.")
+                if not hf_content_is_legal:
+                    raise gr.Error(
+                        "You must consent that the uploaded content is legal."
+                    )
+                return gr.Button(value="Sharing in progress", interactive=False)
+
+            share_share_btn.click(
+                share_precheck,
+                [share_content_is_legal, image_output_slider],
+                share_share_btn,
+                preprocess=False,
+                queue=False,
+            ).success(
+                share_callback,
+                inputs=share_components,
+                outputs=share_share_btn,
+                preprocess=False,
+                queue=False,
+            )
+
+        ### Video tab
+
         video_submit_btn.click(
             fn=process_pipe_video,
             inputs=[video_input],
@@ -729,6 +957,8 @@
             concurrency_limit=1,
         )

+        ### Bas-relief tab
+
         bas_submit_btn.click(
             fn=process_pipe_bas,
             inputs=[
@@ -791,6 +1021,8 @@
             concurrency_limit=1,
         )

+    ### Server launch
+
     demo.queue(
         api_open=False,
     ).launch(
@@ -801,6 +1033,7 @@

 def main():
     CHECKPOINT = "prs-eth/marigold-lcm-v1-0"
+    CROWD_DATA = "crowddata-marigold-lcm-v1-0-space-v1-0"

     if "HF_TOKEN_LOGIN" in os.environ:
         login(token=os.environ["HF_TOKEN_LOGIN"])
@@ -816,7 +1049,18 @@
         pass  # run without xformers

     pipe = pipe.to(device)
-    run_demo_server(pipe)
+
+    hf_writer = None
+    if "HF_TOKEN_LOGIN" in os.environ:
+        hf_writer = HuggingFaceDatasetSaver(
+            os.getenv("HF_TOKEN_LOGIN"),
+            CROWD_DATA,
+            private=True,
+            info_filename="dataset_info.json",
+            separate_dirs=True,
+        )
+
+    run_demo_server(pipe, hf_writer)


 if __name__ == "__main__":
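Note: the submit wiring above leans on Gradio's event chaining: a cheap validation handler (process_image_check) runs unqueued first, and the GPU-backed call is attached with .success() so it only fires when validation does not raise gr.Error; the share path then appends .then() steps to reveal the feedback accordion. A condensed sketch of that gating pattern, with placeholder handlers standing in for the real pipeline:

    import gradio as gr

    def check(path):
        # Fast, unqueued precheck; raising gr.Error aborts the rest of the chain.
        if path is None:
            raise gr.Error("Upload an image first.")

    def heavy(path):
        return path  # stand-in for the expensive GPU pipeline call

    with gr.Blocks() as demo:
        inp = gr.Image(type="filepath")
        out = gr.Image()
        btn = gr.Button("Submit")
        # .success() only runs when check() returned without raising.
        btn.click(check, inputs=inp, outputs=None, queue=False).success(
            heavy, inputs=inp, outputs=out, concurrency_limit=1
        )

    demo.launch()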
extrude.py CHANGED
@@ -336,7 +336,9 @@ def extrude_depth_3d(
     mesh.apply_scale(scaling_factor)

     if prepare_for_3d_printing:
-        rotation_mat = trimesh.transformations.rotation_matrix(np.radians(90), [-1, 0, 0])
+        rotation_mat = trimesh.transformations.rotation_matrix(
+            np.radians(90), [-1, 0, 0]
+        )
         mesh.apply_transform(rotation_mat)

     path_out_base = os.path.splitext(path_depth)[0].replace("_16bit", "")
flagging.py ADDED
@@ -0,0 +1,387 @@
+from __future__ import annotations
+
+import csv
+import json
+import time
+import uuid
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import TYPE_CHECKING, Any
+
+import filelock
+import huggingface_hub
+from gradio_client import utils as client_utils
+from gradio_client.documentation import document
+
+import gradio as gr
+from gradio import utils
+
+if TYPE_CHECKING:
+    from gradio.components import Component
+
+
+class FlaggingCallback(ABC):
+    """
+    An abstract class for defining the methods that any FlaggingCallback should have.
+    """
+
+    @abstractmethod
+    def setup(self, components: list[Component], flagging_dir: str):
+        """
+        This method should be overridden and ensure that everything is set up correctly for flag().
+        This method gets called once at the beginning of the Interface.launch() method.
+        Parameters:
+            components: Set of components that will provide flagged data.
+            flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
+        """
+        pass
+
+    @abstractmethod
+    def flag(
+        self,
+        flag_data: list[Any],
+        flag_option: str = "",
+        username: str | None = None,
+    ) -> int:
+        """
+        This method should be overridden by the FlaggingCallback subclass and may contain optional additional arguments.
+        This gets called every time the <flag> button is pressed.
+        Parameters:
+            flag_data: The data to be flagged.
+            flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
+            username (optional): The username of the user that is flagging the data, if logged in.
+        Returns:
+            (int) The total number of samples that have been flagged.
+        """
+        pass
+
+
+@document()
+class HuggingFaceDatasetSaver(FlaggingCallback):
+    """
+    A callback that saves each flagged sample (both the input and output data) to a HuggingFace dataset.
+
+    Example:
+        import gradio as gr
+        hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
+        def image_classifier(inp):
+            return {'cat': 0.3, 'dog': 0.7}
+        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
+                            allow_flagging="manual", flagging_callback=hf_writer)
+    Guides: using-flagging
+    """
+
+    def __init__(
+        self,
+        hf_token: str,
+        dataset_name: str,
+        private: bool = False,
+        info_filename: str = "dataset_info.json",
+        separate_dirs: bool = False,
+    ):
+        """
+        Parameters:
+            hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset (defaults to the registered one).
+            dataset_name: The repo_id of the dataset to save the data to, e.g. "image-classifier-1" or "username/image-classifier-1".
+            private: Whether the dataset should be private (defaults to False).
+            info_filename: The name of the file to save the dataset info (defaults to "dataset_info.json").
+            separate_dirs: If True, each flagged item will be saved in a separate directory. This makes the flagging more robust to concurrent editing, but may be less convenient to use.
+        """
+        self.hf_token = hf_token
+        self.dataset_id = dataset_name  # TODO: rename parameter (but ensure backward compatibility somehow)
+        self.dataset_private = private
+        self.info_filename = info_filename
+        self.separate_dirs = separate_dirs
+
+    def setup(self, components: list[Component], flagging_dir: str):
+        """
+        Params:
+            flagging_dir (str): local directory where the dataset is cloned,
+            updated, and pushed from.
+        """
+        # Setup dataset on the Hub
+        self.dataset_id = huggingface_hub.create_repo(
+            repo_id=self.dataset_id,
+            token=self.hf_token,
+            private=self.dataset_private,
+            repo_type="dataset",
+            exist_ok=True,
+        ).repo_id
+        path_glob = "**/*.jsonl" if self.separate_dirs else "data.csv"
+        huggingface_hub.metadata_update(
+            repo_id=self.dataset_id,
+            repo_type="dataset",
+            metadata={
+                "configs": [
+                    {
+                        "config_name": "default",
+                        "data_files": [{"split": "train", "path": path_glob}],
+                    }
+                ]
+            },
+            overwrite=True,
+            token=self.hf_token,
+        )
+
+        # Setup flagging dir
+        self.components = components
+        self.dataset_dir = (
+            Path(flagging_dir).absolute() / self.dataset_id.split("/")[-1]
+        )
+        self.dataset_dir.mkdir(parents=True, exist_ok=True)
+        self.infos_file = self.dataset_dir / self.info_filename
+
+        # Download remote files to local
+        remote_files = [self.info_filename]
+        if not self.separate_dirs:
+            # No separate dirs => means all data is in the same CSV file => download it to get its current content
+            remote_files.append("data.csv")
+
+        for filename in remote_files:
+            try:
+                huggingface_hub.hf_hub_download(
+                    repo_id=self.dataset_id,
+                    repo_type="dataset",
+                    filename=filename,
+                    local_dir=self.dataset_dir,
+                    token=self.hf_token,
+                )
+            except huggingface_hub.utils.EntryNotFoundError:
+                pass
+
+    def flag(
+        self,
+        flag_data: list[Any],
+        flag_option: str = "",
+        username: str | None = None,
+    ) -> int:
+        if self.separate_dirs:
+            # JSONL files to support dataset preview on the Hub
+            current_utc_time = datetime.now(timezone.utc)
+            iso_format_without_microseconds = current_utc_time.strftime(
+                "%Y-%m-%dT%H:%M:%S"
+            )
+            milliseconds = int(current_utc_time.microsecond / 1000)
+            unique_id = f"{iso_format_without_microseconds}.{milliseconds:03}Z"
+            if username not in (None, ""):
+                unique_id += f"_U_{username}"
+            else:
+                unique_id += f"_{str(uuid.uuid4())[:8]}"
+            components_dir = self.dataset_dir / unique_id
+            data_file = components_dir / "metadata.jsonl"
+            path_in_repo = unique_id  # upload in sub folder (safer for concurrency)
+        else:
+            # Unique CSV file
+            components_dir = self.dataset_dir
+            data_file = components_dir / "data.csv"
+            path_in_repo = None  # upload at root level
+
+        return self._flag_in_dir(
+            data_file=data_file,
+            components_dir=components_dir,
+            path_in_repo=path_in_repo,
+            flag_data=flag_data,
+            flag_option=flag_option,
+            username=username or "",
+        )
+
+    def _flag_in_dir(
+        self,
+        data_file: Path,
+        components_dir: Path,
+        path_in_repo: str | None,
+        flag_data: list[Any],
+        flag_option: str = "",
+        username: str = "",
+    ) -> int:
+        # Deserialize components (write images/audio to files)
+        features, row = self._deserialize_components(
+            components_dir, flag_data, flag_option, username
+        )
+
+        # Write generic info to dataset_info.json + upload
+        with filelock.FileLock(str(self.infos_file) + ".lock"):
+            if not self.infos_file.exists():
+                self.infos_file.write_text(
+                    json.dumps({"flagged": {"features": features}})
+                )
+
+                huggingface_hub.upload_file(
+                    repo_id=self.dataset_id,
+                    repo_type="dataset",
+                    token=self.hf_token,
+                    path_in_repo=self.infos_file.name,
+                    path_or_fileobj=self.infos_file,
+                )
+
+        headers = list(features.keys())
+
+        if not self.separate_dirs:
+            with filelock.FileLock(components_dir / ".lock"):
+                sample_nb = self._save_as_csv(data_file, headers=headers, row=row)
+                sample_name = str(sample_nb)
+                huggingface_hub.upload_folder(
+                    repo_id=self.dataset_id,
+                    repo_type="dataset",
+                    commit_message=f"Flagged sample #{sample_name}",
+                    path_in_repo=path_in_repo,
+                    ignore_patterns="*.lock",
+                    folder_path=components_dir,
+                    token=self.hf_token,
+                )
+        else:
+            sample_name = self._save_as_jsonl(data_file, headers=headers, row=row)
+            sample_nb = len(
+                [path for path in self.dataset_dir.iterdir() if path.is_dir()]
+            )
+            huggingface_hub.upload_folder(
+                repo_id=self.dataset_id,
+                repo_type="dataset",
+                commit_message=f"Flagged sample #{sample_name}",
+                path_in_repo=path_in_repo,
+                ignore_patterns="*.lock",
+                folder_path=components_dir,
+                token=self.hf_token,
+            )
+
+        return sample_nb
+
+    @staticmethod
+    def _save_as_csv(data_file: Path, headers: list[str], row: list[Any]) -> int:
+        """Save data as CSV and return the sample name (row number)."""
+        is_new = not data_file.exists()
+
+        with data_file.open("a", newline="", encoding="utf-8") as csvfile:
+            writer = csv.writer(csvfile)
+
+            # Write CSV headers if new file
+            if is_new:
+                writer.writerow(utils.sanitize_list_for_csv(headers))
+
+            # Write CSV row for flagged sample
+            writer.writerow(utils.sanitize_list_for_csv(row))
+
+        with data_file.open(encoding="utf-8") as csvfile:
+            return sum(1 for _ in csv.reader(csvfile)) - 1
+
+    @staticmethod
+    def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
+        """Save data as JSONL and return the sample name (uuid)."""
+        Path.mkdir(data_file.parent, parents=True, exist_ok=True)
+        with open(data_file, "w") as f:
+            json.dump(dict(zip(headers, row)), f)
+        return data_file.parent.name
+
+    def _deserialize_components(
+        self,
+        data_dir: Path,
+        flag_data: list[Any],
+        flag_option: str = "",
+        username: str = "",
+    ) -> tuple[dict[Any, Any], list[Any]]:
+        """Deserialize components and return the corresponding row for the flagged sample.
+
+        Images/audio are saved to disk as individual files.
+        """
+        # Components that can have a preview on dataset repos
+        file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
+
+        # Generate the row corresponding to the flagged sample
+        features = OrderedDict()
+        row = []
+        for component, sample in zip(self.components, flag_data):
+            # Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
+            label = component.label or ""
+            save_dir = data_dir / client_utils.strip_invalid_filename_characters(label)
+            save_dir.mkdir(exist_ok=True, parents=True)
+            deserialized = component.flag(sample, save_dir)
+
+            # Base component .flag method returns JSON; extract path from it when it is FileData
+            if component.data_model:
+                data = component.data_model.from_json(json.loads(deserialized))
+                if component.data_model == gr.data_classes.FileData:
+                    deserialized = data.path
+
+            # Add deserialized object to row
+            features[label] = {"dtype": "string", "_type": "Value"}
+            try:
+                deserialized_path = Path(deserialized)
+                if not deserialized_path.exists():
+                    raise FileNotFoundError(f"File {deserialized} not found")
+                row.append(str(deserialized_path.relative_to(self.dataset_dir)))
+            except (FileNotFoundError, TypeError, ValueError):
+                deserialized = "" if deserialized is None else str(deserialized)
+                row.append(deserialized)
+
+            # If component is eligible for a preview, add the URL of the file
+            # Be mindful that images and audio can be None
+            if isinstance(component, tuple(file_preview_types)):  # type: ignore
+                for _component, _type in file_preview_types.items():
+                    if isinstance(component, _component):
+                        features[label + " file"] = {"_type": _type}
+                        break
+                if deserialized:
+                    path_in_repo = str(  # returned filepath is absolute, we want it relative to compute URL
+                        Path(deserialized).relative_to(self.dataset_dir)
+                    ).replace(
+                        "\\", "/"
+                    )
+                    row.append(
+                        huggingface_hub.hf_hub_url(
+                            repo_id=self.dataset_id,
+                            filename=path_in_repo,
+                            repo_type="dataset",
+                        )
+                    )
+                else:
+                    row.append("")
+        features["flag"] = {"dtype": "string", "_type": "Value"}
+        features["username"] = {"dtype": "string", "_type": "Value"}
+        row.append(flag_option)
+        row.append(username)
+        return features, row
+
+
+class FlagMethod:
+    """
+    Helper class that contains the flagging options and calls the flagging method. Also
+    provides visual feedback to the user when flag is clicked.
+    """
+
+    def __init__(
+        self,
+        flagging_callback: FlaggingCallback,
+        label: str,
+        value: str,
+        visual_feedback: bool = True,
+    ):
+        self.flagging_callback = flagging_callback
+        self.label = label
+        self.value = value
+        self.__name__ = "Flag"
+        self.visual_feedback = visual_feedback
+
+    def __call__(
+        self,
+        request: gr.Request,
+        profile: gr.OAuthProfile | None,
+        *flag_data,
+    ):
+        username = None
+        if profile is not None:
+            username = profile.username
+        try:
+            self.flagging_callback.flag(
+                list(flag_data), flag_option=self.value, username=username
+            )
+        except Exception as e:
+            print(f"Error while sharing: {e}")
+            if self.visual_feedback:
+                return gr.Button(value="Sharing error", interactive=False)
+        if not self.visual_feedback:
+            return
+        time.sleep(0.8)  # to provide enough time for the user to observe button change
+        return gr.Button(value="Sharing complete", interactive=False)
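Note: flagging.py vendors Gradio's flagging machinery (FlaggingCallback, HuggingFaceDatasetSaver), adding a FlagMethod whose __call__ declares gr.Request and gr.OAuthProfile | None so Gradio injects them from type annotations, which lets the saver be driven from a Blocks app rather than an Interface. A condensed sketch of the wiring, mirroring app.py above; the token variable and dataset name are illustrative:

    import os

    import gradio as gr

    from flagging import FlagMethod, HuggingFaceDatasetSaver

    hf_writer = HuggingFaceDatasetSaver(
        os.environ["HF_TOKEN_LOGIN"],  # a token with write access
        "username/crowd-feedback",     # illustrative dataset repo_id
        private=True,
        separate_dirs=True,            # one sub-directory per flagged sample
    )

    with gr.Blocks() as demo:
        inp = gr.Image(type="filepath")
        out = gr.Image()
        share_btn = gr.Button("Share")

        components = [inp, out]
        hf_writer.setup(components, "shared_data")  # creates the dataset repo if needed
        share_callback = FlagMethod(hf_writer, "Share", "", visual_feedback=True)

        # Only the components are passed as inputs; the request and OAuth profile
        # parameters of FlagMethod.__call__ are filled in by Gradio itself.
        share_btn.click(
            share_callback, inputs=components, outputs=share_btn, preprocess=False
        )

    demo.launch()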
marigold_depth_estimation_lcm.py CHANGED
@@ -391,7 +391,9 @@ class MarigoldDepthConsistencyPipeline(DiffusionPipeline):
             ).sample  # [B, 4, h, w]

             # compute the previous noisy sample x_t -> x_t-1
-            depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=rng).prev_sample
+            depth_latent = self.scheduler.step(
+                noise_pred, t, depth_latent, generator=rng
+            ).prev_sample

         depth = self._decode_depth(depth_latent)
requirements.txt CHANGED
@@ -1,16 +1,16 @@
 gradio==4.21.0
-gradio-imageslider==0.0.16
+gradio-imageslider==0.0.18
 pygltflib==1.16.1
 trimesh==4.0.5
 imageio
 imageio-ffmpeg
 Pillow

-spaces>=0.25.0
-accelerate>=0.22.0
+spaces==0.25.0
+accelerate==0.25.0
 diffusers==0.27.2
 matplotlib==3.8.2
 scipy==1.11.4
 torch==2.0.1
-transformers>=4.32.1
-xformers>=0.0.21
+transformers==4.36.1
+xformers==0.0.21