import os
import shutil
import subprocess

import gradio as gr

desc = """

badge-github-stars social

Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore examples listed at the bottom.

"""


def init_persistence(purge=False):
    """Point checkpoint and Hugging Face cache locations at the ``/data`` volume.

    Does nothing when ``/data`` does not exist. Otherwise sets the ``ckpt_dir``,
    ``TRANSFORMERS_CACHE``, ``HF_DATASETS_CACHE`` and ``HF_HOME`` environment
    variables for this process and any child processes it starts
    (``ckpt_dir`` is presumably read by ``script/download_weights.sh`` --
    TODO confirm).

    Args:
        purge: If true, also delete the contents of ``/data/Marigold_ckpt``.
    """
    if not os.path.exists('/data'):
        return
    os.environ['ckpt_dir'] = "/data/Marigold_ckpt"
    os.environ['TRANSFORMERS_CACHE'] = "/data/hfcache"
    os.environ['HF_DATASETS_CACHE'] = "/data/hfcache"
    os.environ['HF_HOME'] = "/data/hfcache"
    if purge:
        os.system("rm -rf /data/Marigold_ckpt/*")


def download_code_weights():
    """Fetch the Marigold code and run its weight-download script.

    The clone is skipped when a ``Marigold`` directory already exists (e.g.
    after an app restart), because ``git clone`` into an existing directory
    can only fail. The trailing ``ls`` calls only log the contents of the
    persistent volume for debugging.
    """
    if not os.path.isdir('Marigold'):
        os.system('git clone https://github.com/prs-eth/Marigold.git')
    os.system('cd Marigold && bash script/download_weights.sh')
    os.system('echo /data && ls -la /data')
    os.system('echo /data/Marigold_ckpt && ls -la /data/Marigold_ckpt')
    os.system('echo /data/Marigold_ckpt/Marigold_v1_merged && ls -la /data/Marigold_ckpt/Marigold_v1_merged')


def find_first_png(directory):
    """Return the path of the first ``.png`` file in *directory*, or ``None``.

    "First" follows ``os.listdir`` order, which is arbitrary. Callers expect
    the directory to hold a single PNG. A missing directory (e.g. because the
    inference run failed before creating it) also yields ``None`` instead of
    raising ``FileNotFoundError``.
    """
    if not os.path.isdir(directory):
        return None
    for name in os.listdir(directory):
        if name.lower().endswith(".png"):
            return os.path.join(directory, name)
    return None


def marigold_process(path_input, path_out_png=None, path_out_obj=None, path_out_2_png=None):
    """Run Marigold depth estimation on one image and return the output PNGs.

    Args:
        path_input: Path of the uploaded input image.
        path_out_png: Precomputed colored-depth result (first hidden input).
        path_out_obj: Precomputed 16-bit depth result (second hidden input).
        path_out_2_png: Legacy third precomputed result; kept for
            backward compatibility.

    Returns:
        ``(colored_depth_png_path, bw_depth_png_path)``, matching the two
        output components of the interface.

    Raises:
        gr.Error: If the inference command cannot be started or produces no
            output PNGs.
    """
    # Short-circuit when precomputed results are supplied. Returns two values
    # to match the interface's two outputs (this previously returned three).
    # NOTE(review): the interface wires only two hidden inputs, so Gradio never
    # sets path_out_2_png and this branch is reachable only by direct callers.
    # Relaxing the condition would let stale hidden-input values from a
    # clicked example be returned for a new upload -- confirm before changing.
    if path_out_png is not None and path_out_obj is not None and path_out_2_png is not None:
        return path_out_png, path_out_obj

    path_input_dir = path_input + ".input"
    path_output_dir = path_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    cmd = ["python3", "run.py"]
    if os.path.exists('/data'):
        cmd += ["--checkpoint", "/data/Marigold_ckpt/Marigold_v1_merged"]
    cmd += [
        "--input_rgb_dir", path_input_dir,
        "--output_dir", path_output_dir,
        "--n_infer", "10",
        "--denoise_steps", "10",
    ]
    # An argument list with no shell means the upload path is never parsed by
    # a shell, so quotes or `$(...)` in it cannot inject commands. As with the
    # previous os.system call, a non-zero exit status is not raised here;
    # failure is detected below by the absence of output files.
    try:
        subprocess.run(cmd, cwd="Marigold")
    except OSError as err:  # e.g. Marigold/ missing or python3 not on PATH
        raise gr.Error("Processing failed") from err

    # Output subfolders: depth_colored, depth_bw, depth_npy (only PNGs shown).
    path_out_colored = find_first_png(os.path.join(path_output_dir, "depth_colored"))
    path_out_bw = find_first_png(os.path.join(path_output_dir, "depth_bw"))
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if path_out_colored is None or path_out_bw is None:
        raise gr.Error("Processing failed")
    return path_out_colored, path_out_bw


# Gradio passes the three input components positionally to marigold_process:
# (path_input, path_out_png, path_out_obj). The two hidden File inputs carry
# precomputed results for example rows.
iface = gr.Interface(
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    fn=marigold_process,
    inputs=[
        gr.Image(
            label="Input Image",
            type="filepath",
        ),
        gr.File(
            label="Predicted depth (red-near, blue-far)",
            visible=False,
        ),
        gr.File(
            label="Predicted depth (16-bit PNG)",
            visible=False,
        ),
    ],
    outputs=[
        gr.Image(
            label="Predicted depth (red-near, blue-far)",
            type="pil",
        ),
        gr.Image(
            label="Predicted depth (16-bit PNG)",
            type="pil",
            elem_classes="imgdownload",
        ),
    ],
    allow_flagging="never",
    # NOTE(review): examples are disabled, so cache_examples below has nothing
    # to cache. Each row would need one value per input component (3).
    # examples=[
    #     [
    #         os.path.join(os.path.dirname(__file__), "files/test.png"),
    #         os.path.join(os.path.dirname(__file__), "files/test.png.out.png"),
    #         os.path.join(os.path.dirname(__file__), "files/test.png.out.2.png"),
    #     ],
    # ],
    css="""
    .viewport {
        aspect-ratio: 4/3;
    }
    .imgdownload {
        height: 32px;
    }
    """,
    cache_examples=True,
)

if __name__ == "__main__":
    init_persistence()
    download_code_weights()
    iface.queue().launch(server_name="0.0.0.0", server_port=7860)