File size: 4,309 Bytes
2c83504 4cce5d7 2c83504 0ea72b7 2c83504 4cce5d7 2c83504 4cce5d7 2c83504 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import os
import shutil
import gradio as gr
# HTML rendered under the app title: project badges plus a short usage note.
desc = """
<p align="center">
<a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="website">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arxiv">
</a>
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
<p align="justify">
Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore the examples listed at the bottom.
</p>
"""
def download_code(repo_url="https://github.com/prs-eth/Marigold.git", target_dir="Marigold"):
    """Clone the Marigold repository into *target_dir* unless it already exists.

    ``marigold_process`` executes ``run.py`` from a ``Marigold`` directory
    relative to the current working directory, which is why that is the
    default destination (it is also the name ``git clone`` derives from the
    default URL).

    Args:
        repo_url: Git URL to clone.
        target_dir: Destination directory for the checkout.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with an error.
        FileNotFoundError: if the ``git`` executable is not available.
    """
    import subprocess

    # A previous start may already have cloned the repo; cloning into an
    # existing non-empty directory would make git fail.
    if os.path.isdir(target_dir):
        return
    # Argument list instead of a shell string: nothing is shell-parsed, and a
    # failed clone is reported instead of being silently ignored.
    subprocess.run(["git", "clone", repo_url, target_dir], check=True)
def find_first_png(directory):
    """Return the path of a PNG file directly inside *directory*, or None.

    Entries are scanned in sorted name order so the result is deterministic
    (``os.listdir`` returns entries in arbitrary order). The ``.png``
    extension is matched case-insensitively. A missing directory yields None
    instead of an exception, so callers can report a failed run themselves.

    Args:
        directory: Directory to search (not recursive).

    Returns:
        ``os.path.join(directory, name)`` for the first matching entry, or
        None if there is no match or the directory does not exist.
    """
    try:
        names = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError):
        # e.g. the depth run failed before creating its output sub-directory.
        return None
    for name in names:
        if name.lower().endswith(".png"):
            return os.path.join(directory, name)
    return None
def marigold_process(path_input, path_out_png=None, path_out_obj=None, path_out_2_png=None):
    """Estimate depth for one image by running ``Marigold/run.py``.

    Args:
        path_input: Path of the input RGB image.
        path_out_png: Optional precomputed colored depth map.
        path_out_obj: Optional precomputed 16-bit depth PNG.
        path_out_2_png: Optional third precomputed path. Only its presence is
            checked: when all three precomputed paths are given, inference is
            skipped and the first two are returned unchanged.

    Returns:
        ``(colored_depth_path, bw_depth_path)`` -- always two values, one per
        output component of the interface.

    Raises:
        RuntimeError: if the run produced no PNG in ``depth_colored`` or
            ``depth_bw`` under the output directory.
    """
    import subprocess
    import sys

    if path_out_png is not None and path_out_obj is not None and path_out_2_png is not None:
        # Return exactly two values to match the two declared outputs; the
        # previous three-value return did not match the interface.
        return path_out_png, path_out_obj

    # run.py is executed from inside the Marigold checkout, so resolve the
    # input path first; a relative path would otherwise point to the wrong place.
    abs_input = os.path.abspath(path_input)
    path_input_dir = abs_input + ".input"
    path_output_dir = abs_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    # run.py takes an input directory, so stage the single image in its own folder.
    shutil.copy(abs_input, path_input_dir)

    # Argument list without a shell: the uploaded file name cannot inject
    # commands. sys.executable runs run.py with the interpreter serving the app.
    completed = subprocess.run(
        [
            sys.executable, "run.py",
            "--input_rgb_dir", path_input_dir,
            "--output_dir", path_output_dir,
            "--n_infer", "10",
            "--denoise_steps", "10",
        ],
        cwd="Marigold",
    )

    path_out_colored = find_first_png(os.path.join(path_output_dir, "depth_colored"))
    path_out_bw = find_first_png(os.path.join(path_output_dir, "depth_bw"))
    if path_out_colored is None or path_out_bw is None:
        # Explicit raise rather than assert: asserts are stripped under ``python -O``.
        raise RuntimeError(f"Processing failed (run.py exit code {completed.returncode})")
    return path_out_colored, path_out_bw
# Gradio front end: one image in, colored and 16-bit depth maps out.
iface = gr.Interface(
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    fn=marigold_process,
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        # Hidden slots filled positionally by the 2nd and 3rd entries of each
        # example row below (the *_vis.png and *_pred.png files).
        gr.File(label="Predicted depth (red-near, blue-far)", visible=False),
        gr.File(label="Predicted depth (16-bit PNG)", visible=False),
    ],
    outputs=[
        gr.Image(label="Predicted depth (red-near, blue-far)", type="pil"),
        gr.Image(
            label="Predicted depth (16-bit PNG)",
            type="pil",
            elem_classes="imgdownload",
        ),
    ],
    allow_flagging="never",
    # One row per bundled sample: input image, colored depth, 16-bit depth,
    # all stored under files/ next to this script.
    examples=[
        [
            os.path.join(os.path.dirname(__file__), f"files/{stem}{suffix}")
            for suffix in (".jpg", "_vis.png", "_pred.png")
        ]
        for stem in ("bee", "cat", "swings")
    ],
    css="""
.viewport {
aspect-ratio: 4/3;
}
.imgdownload {
height: 32px;
}
""",
    cache_examples=True,
)
if __name__ == "__main__":
    # Make sure the Marigold checkout exists before serving requests.
    download_code()
    # Enable request queueing and listen on all interfaces at port 7860.
    iface.queue().launch(
        server_port=7860,
        server_name="0.0.0.0",
    )
|