import os
import shutil
import subprocess

import gradio as gr


# HTML header rendered at the top of the Gradio page (passed to gr.Interface
# as `description`): project badges followed by a short usage blurb.
desc = """
    <p align="center">
    <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-website.svg" alt="website">
    </a>
    <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg" alt="arXiv">
    </a>
    <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
    </a>
    <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
        <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
    </a>
    </p>
    <p align="justify">
    Marigold is the new state-of-the-art depth estimator for images in the wild. Upload your image into the pane on the left side, or explore the examples listed at the bottom.
    </p>
"""


def download_code():
    """Clone the Marigold repository into ``./Marigold`` unless it is already there.

    The checkout is created relative to the current working directory, which is
    the same directory ``marigold_process`` enters (``Marigold``) to run
    ``run.py``.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits with an error, so
            a broken setup fails at startup instead of on every request.
    """
    # On restarts the checkout already exists; `git clone` would only fail with
    # "destination path already exists", so skip it explicitly.
    if os.path.isdir("Marigold") and os.listdir("Marigold"):
        return
    # Argument-list form: no shell involved.
    subprocess.run(
        ["git", "clone", "https://github.com/prs-eth/Marigold.git", "Marigold"],
        check=True,
    )


def find_first_png(directory):
    """Return the path of the first ``.png`` file in *directory*, or ``None``.

    "First" means first in sorted filename order, so the result does not depend
    on the arbitrary order in which ``os.listdir`` yields entries. The extension
    check is case-insensitive. Subdirectories are skipped even if their names
    end in ``.png``. The search is not recursive.

    Args:
        directory: Directory to search.

    Returns:
        The joined path of the matching file. Returns ``None`` when no PNG file
        is found or when *directory* does not exist or is not a directory.
    """
    try:
        names = sorted(os.listdir(directory))
    except (FileNotFoundError, NotADirectoryError):
        # A missing output directory means nothing was produced. Report it as
        # "not found" so callers handle it the same way as an empty directory.
        return None
    for name in names:
        path = os.path.join(directory, name)
        if name.lower().endswith(".png") and os.path.isfile(path):
            return path
    return None


def marigold_process(path_input, path_out_png=None, path_out_obj=None, path_out_2_png=None):
    """Run Marigold on one image and return the paths of the two depth PNGs.

    The input is copied into ``<path_input>.input``. ``python3 run.py`` is then
    run inside the ``Marigold`` directory, writing into ``<path_input>.output``.
    The first PNG from its ``depth_colored`` and ``depth_bw`` subdirectories is
    returned.

    Args:
        path_input: Path to the input image.
        path_out_png: Optional precomputed colored depth image. In the Interface
            this is the hidden "red-near, blue-far" File input.
        path_out_obj: Optional precomputed depth image. In the Interface this is
            the hidden "16-bit PNG" File input.
        path_out_2_png: Optional extra precomputed path. It only takes part in
            the short-circuit condition below.

    Returns:
        ``(colored_depth_png, bw_depth_png)``, matching the two output
        components of the Interface.

    Raises:
        RuntimeError: if ``run.py`` exits with a non-zero status or does not
            produce both PNGs.
    """
    # Short-circuit for callers that already have the outputs. The result must
    # be a 2-tuple to match the normal return value and the Interface outputs.
    # The Interface in this file wires only three inputs, so path_out_2_png is
    # never supplied from the UI and this branch is not taken there.
    if path_out_png is not None and path_out_obj is not None and path_out_2_png is not None:
        return path_out_png, path_out_obj

    # Absolute paths are required because the child process runs in ./Marigold.
    path_input = os.path.abspath(path_input)
    path_input_dir = path_input + ".input"
    path_output_dir = path_input + ".output"
    os.makedirs(path_input_dir, exist_ok=True)
    os.makedirs(path_output_dir, exist_ok=True)
    shutil.copy(path_input, path_input_dir)

    # Argument-list form with no shell: quotes, `$`, or other metacharacters in
    # the upload path cannot break the command or inject into it.
    result = subprocess.run(
        [
            "python3", "run.py",
            "--input_rgb_dir", path_input_dir,
            "--output_dir", path_output_dir,
            "--n_infer", "10",
            "--denoise_steps", "10",
        ],
        cwd="Marigold",
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"Processing failed (run.py exited with status {result.returncode})"
        )

    path_out_colored = find_first_png(os.path.join(path_output_dir, "depth_colored"))
    path_out_bw = find_first_png(os.path.join(path_output_dir, "depth_bw"))
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if path_out_colored is None or path_out_bw is None:
        raise RuntimeError("Processing failed")

    return path_out_colored, path_out_bw


# Gradio UI wiring. Besides the visible image, the inputs include two hidden
# File components. They let each example row carry an input image plus its two
# precomputed output files; Gradio passes all three input values to
# `marigold_process` positionally.
iface = gr.Interface(
    fn=marigold_process,
    title="Marigold Depth Estimation",
    description=desc,
    thumbnail="marigold_logo_square.jpg",
    inputs=[
        gr.Image(label="Input Image", type="filepath"),
        gr.File(label="Predicted depth (red-near, blue-far)", visible=False),
        gr.File(label="Predicted depth (16-bit PNG)", visible=False),
    ],
    outputs=[
        gr.Image(label="Predicted depth (red-near, blue-far)", type="pil"),
        gr.Image(
            label="Predicted depth (16-bit PNG)",
            type="pil",
            elem_classes="imgdownload",
        ),
    ],
    allow_flagging="never",
    # Each row: input image, colored visualization, raw prediction, resolved
    # relative to this script's directory.
    examples=[
        [os.path.join(os.path.dirname(__file__), rel_path) for rel_path in example_row]
        for example_row in (
            ("files/bee.jpg", "files/bee_vis.png", "files/bee_pred.png"),
            ("files/cat.jpg", "files/cat_vis.png", "files/cat_pred.png"),
            ("files/swings.jpg", "files/swings_vis.png", "files/swings_pred.png"),
        )
    ],
    cache_examples=True,
    css="""
    .viewport {
        aspect-ratio: 4/3;
    }
    .imgdownload {
        height: 32px;
    }
    """,
)


if __name__ == "__main__":
    # Fetch the inference code first, then serve the queued app on all
    # network interfaces at port 7860.
    download_code()
    app = iface.queue()
    app.launch(server_port=7860, server_name="0.0.0.0")