Spaces:
Running
Running
# Copyright 2024 Anton Obukhov, ETH Zurich. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# -------------------------------------------------------------------------- | |
# If you find this code useful, we kindly ask you to cite our paper in your work. | |
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation | |
# More information about the method can be found at https://marigoldmonodepth.github.io | |
# -------------------------------------------------------------------------- | |
from __future__ import annotations | |
import functools | |
import os | |
import tempfile | |
import warnings | |
import zipfile | |
from io import BytesIO | |
import diffusers | |
import gradio as gr | |
import imageio as imageio | |
import numpy as np | |
import spaces | |
import torch as torch | |
from PIL import Image | |
from diffusers import MarigoldDepthPipeline | |
from gradio_imageslider import ImageSlider | |
from huggingface_hub import login | |
from tqdm import tqdm | |
from extrude import extrude_depth_3d | |
from gradio_patches.examples import Examples | |
from gradio_patches.flagging import FlagMethod, HuggingFaceDatasetSaver | |
warnings.filterwarnings( | |
"ignore", message=".*LoginButton created outside of a Blocks context.*" | |
) | |
default_seed = 2024 | |
default_batch_size = 4 | |
default_image_num_inference_steps = 4 | |
default_image_ensemble_size = 1 | |
default_image_processing_resolution = 768 | |
default_image_reproducuble = True | |
default_video_depth_latent_init_strength = 0.1 | |
default_video_num_inference_steps = 1 | |
default_video_ensemble_size = 1 | |
default_video_processing_resolution = 768 | |
default_video_out_max_frames = 450 | |
default_bas_plane_near = 0.0 | |
default_bas_plane_far = 1.0 | |
default_bas_embossing = 20 | |
default_bas_num_inference_steps = 4 | |
default_bas_ensemble_size = 1 | |
default_bas_processing_resolution = 768 | |
default_bas_size_longest_px = 512 | |
default_bas_size_longest_cm = 10 | |
default_bas_filter_size = 3 | |
default_bas_frame_thickness = 5 | |
default_bas_frame_near = 1 | |
default_bas_frame_far = 1 | |
default_share_always_show_hf_logout_btn = True | |
default_share_always_show_accordion = False | |
def process_image_check(path_input): | |
if path_input is None: | |
raise gr.Error( | |
"Missing image in the first pane: upload a file or use one from the gallery below." | |
) | |
def process_image( | |
pipe, | |
path_input, | |
num_inference_steps=default_image_num_inference_steps, | |
ensemble_size=default_image_ensemble_size, | |
processing_resolution=default_image_processing_resolution, | |
): | |
name_base, name_ext = os.path.splitext(os.path.basename(path_input)) | |
print(f"Processing image {name_base}{name_ext}") | |
path_output_dir = tempfile.mkdtemp() | |
path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy") | |
path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png") | |
path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png") | |
input_image = Image.open(path_input) | |
generator = torch.Generator(device=pipe.device).manual_seed(default_seed) | |
pipe_out = pipe( | |
input_image, | |
num_inference_steps=num_inference_steps, | |
ensemble_size=ensemble_size, | |
processing_resolution=processing_resolution, | |
batch_size=1 if processing_resolution == 0 else default_batch_size, | |
generator=generator, | |
) | |
depth_pred = pipe_out.prediction[0, :, :, 0] | |
depth_colored = pipe.image_processor.visualize_depth(pipe_out.prediction)[0] | |
depth_16bit = pipe.image_processor.export_depth_to_16bit_png(pipe_out.prediction)[0] | |
np.save(path_out_fp32, depth_pred) | |
depth_16bit.save(path_out_16bit) | |
depth_colored.save(path_out_vis) | |
return ( | |
[path_out_16bit, path_out_vis], | |
[path_out_16bit, path_out_fp32, path_out_vis], | |
) | |
def process_video( | |
pipe, | |
path_input, | |
depth_latent_init_strength=default_video_depth_latent_init_strength, | |
num_inference_steps=default_video_num_inference_steps, | |
ensemble_size=default_video_ensemble_size, | |
processing_resolution=default_video_processing_resolution, | |
out_max_frames=default_video_out_max_frames, | |
progress=gr.Progress(), | |
): | |
if path_input is None: | |
raise gr.Error( | |
"Missing video in the first pane: upload a file or use one from the gallery below." | |
) | |
name_base, name_ext = os.path.splitext(os.path.basename(path_input)) | |
print(f"Processing video {name_base}{name_ext}") | |
path_output_dir = tempfile.mkdtemp() | |
path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.mp4") | |
path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.zip") | |
generator = torch.Generator(device=pipe.device).manual_seed(default_seed) | |
reader, writer, zipf = None, None, None | |
try: | |
pipe.vae, pipe.vae_tiny = pipe.vae_tiny, pipe.vae | |
reader = imageio.get_reader(path_input) | |
meta_data = reader.get_meta_data() | |
fps = meta_data["fps"] | |
size = meta_data["size"] | |
max_orig = max(size) | |
duration_sec = meta_data["duration"] | |
total_frames = int(fps * duration_sec) | |
out_duration_sec = out_max_frames / fps | |
if duration_sec > out_duration_sec: | |
gr.Warning( | |
f"Only the first ~{int(out_duration_sec)} seconds will be processed; " | |
f"use alternative setups such as ComfyUI Marigold node for full processing" | |
) | |
writer = imageio.get_writer(path_out_vis, fps=fps) | |
zipf = zipfile.ZipFile(path_out_16bit, "w", zipfile.ZIP_DEFLATED) | |
last_frame_latent = None | |
latent_common = torch.randn( | |
( | |
1, | |
4, | |
(768 * size[1] + 7 * max_orig) // (8 * max_orig), | |
(768 * size[0] + 7 * max_orig) // (8 * max_orig), | |
), | |
generator=generator, | |
device=pipe.device, | |
dtype=torch.float16, | |
) | |
out_frame_id = 0 | |
pbar = tqdm(desc="Processing Video", total=min(out_max_frames, total_frames)) | |
for frame_id, frame in enumerate(reader): | |
out_frame_id += 1 | |
pbar.update(1) | |
if out_frame_id > out_max_frames: | |
break | |
frame_pil = Image.fromarray(frame) | |
latents = latent_common | |
if last_frame_latent is not None: | |
assert ( | |
last_frame_latent.shape == latent_common.shape | |
), f"{last_frame_latent.shape}, {latent_common.shape}" | |
latents = ( | |
1 - depth_latent_init_strength | |
) * latents + depth_latent_init_strength * last_frame_latent | |
pipe_out = pipe( | |
frame_pil, | |
num_inference_steps=num_inference_steps, | |
ensemble_size=ensemble_size, | |
processing_resolution=processing_resolution, | |
match_input_resolution=False, | |
batch_size=1, | |
latents=latents, | |
output_latent=True, | |
) | |
last_frame_latent = pipe_out.latent | |
processed_frame = pipe.image_processor.visualize_depth( # noqa | |
pipe_out.prediction | |
)[0] | |
processed_frame = imageio.core.util.Array(np.array(processed_frame)) | |
writer.append_data(processed_frame) | |
archive_path = os.path.join( | |
f"{name_base}_depth_16bit", f"{out_frame_id:05d}.png" | |
) | |
img_byte_arr = BytesIO() | |
processed_frame = pipe.image_processor.export_depth_to_16bit_png( | |
pipe_out.prediction | |
)[0] | |
processed_frame.save(img_byte_arr, format="png") | |
img_byte_arr.seek(0) | |
zipf.writestr(archive_path, img_byte_arr.read()) | |
finally: | |
if zipf is not None: | |
zipf.close() | |
if writer is not None: | |
writer.close() | |
if reader is not None: | |
reader.close() | |
pipe.vae, pipe.vae_tiny = pipe.vae_tiny, pipe.vae | |
return ( | |
path_out_vis, | |
[path_out_vis, path_out_16bit], | |
) | |
def process_bas( | |
pipe, | |
path_input, | |
plane_near=default_bas_plane_near, | |
plane_far=default_bas_plane_far, | |
embossing=default_bas_embossing, | |
num_inference_steps=default_bas_num_inference_steps, | |
ensemble_size=default_bas_ensemble_size, | |
processing_resolution=default_bas_processing_resolution, | |
size_longest_px=default_bas_size_longest_px, | |
size_longest_cm=default_bas_size_longest_cm, | |
filter_size=default_bas_filter_size, | |
frame_thickness=default_bas_frame_thickness, | |
frame_near=default_bas_frame_near, | |
frame_far=default_bas_frame_far, | |
): | |
if path_input is None: | |
raise gr.Error( | |
"Missing image in the first pane: upload a file or use one from the gallery below." | |
) | |
if plane_near >= plane_far: | |
raise gr.Error("NEAR plane must have a value smaller than the FAR plane") | |
name_base, name_ext = os.path.splitext(os.path.basename(path_input)) | |
print(f"Processing bas-relief {name_base}{name_ext}") | |
path_output_dir = tempfile.mkdtemp() | |
input_image = Image.open(path_input) | |
generator = torch.Generator(device=pipe.device).manual_seed(default_seed) | |
pipe_out = pipe( | |
input_image, | |
num_inference_steps=num_inference_steps, | |
ensemble_size=ensemble_size, | |
processing_resolution=processing_resolution, | |
generator=generator, | |
) | |
depth_pred = pipe_out.prediction[0, :, :, 0] * 65535 | |
def _process_3d( | |
size_longest_px, | |
filter_size, | |
vertex_colors, | |
scene_lights, | |
output_model_scale=None, | |
prepare_for_3d_printing=False, | |
zip_outputs=False, | |
): | |
image_rgb_w, image_rgb_h = input_image.width, input_image.height | |
image_rgb_d = max(image_rgb_w, image_rgb_h) | |
image_new_w = size_longest_px * image_rgb_w // image_rgb_d | |
image_new_h = size_longest_px * image_rgb_h // image_rgb_d | |
image_rgb_new = os.path.join( | |
path_output_dir, f"{name_base}_rgb_{size_longest_px}{name_ext}" | |
) | |
image_depth_new = os.path.join( | |
path_output_dir, f"{name_base}_depth_{size_longest_px}.png" | |
) | |
input_image.resize((image_new_w, image_new_h), Image.LANCZOS).save( | |
image_rgb_new | |
) | |
Image.fromarray(depth_pred).convert(mode="F").resize( | |
(image_new_w, image_new_h), Image.BILINEAR | |
).convert("I").save(image_depth_new) | |
path_glb, path_stl, path_obj = extrude_depth_3d( | |
image_rgb_new, | |
image_depth_new, | |
output_model_scale=( | |
size_longest_cm * 10 | |
if output_model_scale is None | |
else output_model_scale | |
), | |
filter_size=filter_size, | |
coef_near=plane_near, | |
coef_far=plane_far, | |
emboss=embossing / 100, | |
f_thic=frame_thickness / 100, | |
f_near=frame_near / 100, | |
f_back=frame_far / 100, | |
vertex_colors=vertex_colors, | |
scene_lights=scene_lights, | |
prepare_for_3d_printing=prepare_for_3d_printing, | |
zip_outputs=zip_outputs, | |
) | |
return path_glb, path_stl, path_obj | |
path_viewer_glb, _, _ = _process_3d( | |
256, filter_size, vertex_colors=False, scene_lights=True, output_model_scale=1 | |
) | |
path_files_glb, path_files_stl, path_files_obj = _process_3d( | |
size_longest_px, | |
filter_size, | |
vertex_colors=True, | |
scene_lights=False, | |
prepare_for_3d_printing=True, | |
zip_outputs=True, | |
) | |
return path_viewer_glb, [path_files_glb, path_files_stl, path_files_obj] | |
def run_demo_server(pipe, hf_writer=None): | |
process_pipe_image = spaces.GPU(functools.partial(process_image, pipe)) | |
process_pipe_video = spaces.GPU( | |
functools.partial(process_video, pipe), duration=120 | |
) | |
process_pipe_bas = spaces.GPU(functools.partial(process_bas, pipe)) | |
gradio_theme = gr.themes.Default() | |
with gr.Blocks( | |
theme=gradio_theme, | |
title="Marigold-LCM Depth Estimation", | |
css=""" | |
#download { | |
height: 118px; | |
} | |
.slider .inner { | |
width: 5px; | |
background: #FFF; | |
} | |
.viewport { | |
aspect-ratio: 4/3; | |
} | |
.tabs button.selected { | |
font-size: 20px !important; | |
color: crimson !important; | |
} | |
h1 { | |
text-align: center; | |
display: block; | |
} | |
h2 { | |
text-align: center; | |
display: block; | |
} | |
h3 { | |
text-align: center; | |
display: block; | |
} | |
.md_feedback li { | |
margin-bottom: 0px !important; | |
} | |
""", | |
head=""" | |
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script> | |
<script> | |
window.dataLayer = window.dataLayer || []; | |
function gtag() {dataLayer.push(arguments);} | |
gtag('js', new Date()); | |
gtag('config', 'G-1FWSVCGZTG'); | |
</script> | |
""", | |
) as demo: | |
if hf_writer is not None: | |
print("Creating login button") | |
share_login_btn = gr.LoginButton(size="sm", scale=1, render=False) | |
print("Created login button") | |
share_login_btn.activate() | |
print("Activated login button") | |
gr.Markdown( | |
""" | |
# Marigold-LCM Depth Estimation | |
<p align="center"> | |
<a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" | |
style="display: inline-block;"> | |
<img src="https://www.obukhov.ai/img/badges/badge-website.svg"> | |
</a> | |
<a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" | |
style="display: inline-block;"> | |
<img src="https://www.obukhov.ai/img/badges/badge-pdf.svg"> | |
</a> | |
<a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" | |
style="display: inline-block;"> | |
<img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" | |
alt="badge-github-stars"> | |
</a> | |
<a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" | |
style="display: inline-block;"> | |
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social"> | |
</a> | |
</p> | |
<p align="justify"> | |
Marigold-LCM is the fast version of Marigold, the state-of-the-art depth estimator for images in the | |
wild. It combines the power of the original Marigold 10-step estimator and the Latent Consistency | |
Models, delivering high-quality results in as little as <b>one step</b>. We provide three functions | |
in this demo: Image, Video, and Bas-relief 3D processing — <b>see the tabs below</b>. Upload your | |
content into the <b>first</b> pane, or click any of the <b>examples</b> below. Wait a second (for | |
images and 3D) or a minute (for videos), and interact with the result in the <b>second</b> pane. To | |
avoid queuing, fork the demo into your profile. | |
<a href="https://huggingface.co/spaces/prs-eth/marigold"> | |
The original Marigold demo is also available | |
</a>. | |
</p> | |
""" | |
) | |
def get_share_instructions(is_full): | |
out = ( | |
"### Help us improve Marigold! If the output is not what you expected, " | |
"you can help us by sharing it with us privately.\n" | |
) | |
if is_full: | |
out += ( | |
"1. Sign into your Hugging Face account using the button below.\n" | |
"1. Signing in may reset the demo and results; in that case, process the image again.\n" | |
) | |
out += "1. Review and agree to the terms of usage and enter an optional message to us.\n" | |
out += "1. Click the 'Share' button to submit the image to us privately.\n" | |
return out | |
def get_share_conditioned_on_login(profile: gr.OAuthProfile | None): | |
state_logged_out = profile is None | |
return get_share_instructions(is_full=state_logged_out), gr.Button( | |
visible=(state_logged_out or default_share_always_show_hf_logout_btn) | |
) | |
with gr.Tabs(elem_classes=["tabs"]): | |
with gr.Tab("Image"): | |
with gr.Row(): | |
with gr.Column(): | |
image_input = gr.Image( | |
label="Input Image", | |
type="filepath", | |
) | |
with gr.Row(): | |
image_submit_btn = gr.Button( | |
value="Compute Depth", variant="primary" | |
) | |
image_reset_btn = gr.Button(value="Reset") | |
with gr.Accordion("Advanced options", open=False): | |
image_num_inference_steps = gr.Slider( | |
label="Number of denoising steps", | |
minimum=1, | |
maximum=4, | |
step=1, | |
value=default_image_num_inference_steps, | |
) | |
image_ensemble_size = gr.Slider( | |
label="Ensemble size", | |
minimum=1, | |
maximum=10, | |
step=1, | |
value=default_image_ensemble_size, | |
) | |
image_processing_resolution = gr.Radio( | |
[ | |
("Native", 0), | |
("Recommended", 768), | |
], | |
label="Processing resolution", | |
value=default_image_processing_resolution, | |
) | |
with gr.Column(): | |
image_output_slider = ImageSlider( | |
label="Predicted depth (red-near, blue-far)", | |
type="filepath", | |
show_download_button=True, | |
show_share_button=True, | |
interactive=False, | |
elem_classes="slider", | |
position=0.25, | |
) | |
image_output_files = gr.Files( | |
label="Depth outputs", | |
elem_id="download", | |
interactive=False, | |
) | |
if hf_writer is not None: | |
with gr.Accordion( | |
"Feedback", | |
open=False, | |
visible=default_share_always_show_accordion, | |
) as share_box: | |
share_instructions = gr.Markdown( | |
get_share_instructions(is_full=True), | |
elem_classes="md_feedback", | |
) | |
share_transfer_of_rights = gr.Checkbox( | |
label="(Optional) I own or hold necessary rights to the submitted image. By " | |
"checking this box, I grant an irrevocable, non-exclusive, transferable, " | |
"royalty-free, worldwide license to use the uploaded image, including for " | |
"publishing, reproducing, and model training. [transfer_of_rights]", | |
scale=1, | |
) | |
share_content_is_legal = gr.Checkbox( | |
label="By checking this box, I acknowledge that my uploaded content is legal and " | |
"safe, and that I am solely responsible for ensuring it complies with all " | |
"applicable laws and regulations. Additionally, I am aware that my Hugging Face " | |
"username is collected. [content_is_legal]", | |
scale=1, | |
) | |
share_reason = gr.Textbox( | |
label="(Optional) Reason for feedback", | |
max_lines=1, | |
interactive=True, | |
) | |
with gr.Row(): | |
share_login_btn.render() | |
share_share_btn = gr.Button( | |
"Share", variant="stop", scale=1 | |
) | |
Examples( | |
fn=process_pipe_image, | |
examples=[ | |
os.path.join("files", "image", name) | |
for name in [ | |
"arc.jpeg", | |
"berries.jpeg", | |
"butterfly.jpeg", | |
"cat.jpg", | |
"concert.jpeg", | |
"dog.jpeg", | |
"doughnuts.jpeg", | |
"einstein.jpg", | |
"food.jpeg", | |
"glasses.jpeg", | |
"house.jpg", | |
"lake.jpeg", | |
"marigold.jpeg", | |
"portrait_1.jpeg", | |
"portrait_2.jpeg", | |
"pumpkins.jpg", | |
"puzzle.jpeg", | |
"road.jpg", | |
"scientists.jpg", | |
"surfboards.jpeg", | |
"surfer.jpeg", | |
"swings.jpg", | |
"switzerland.jpeg", | |
"teamwork.jpeg", | |
"wave.jpeg", | |
] | |
], | |
inputs=[image_input], | |
outputs=[image_output_slider, image_output_files], | |
cache_examples=True, | |
directory_name="examples_image", | |
) | |
with gr.Tab("Video"): | |
with gr.Row(): | |
with gr.Column(): | |
video_input = gr.Video( | |
label="Input Video", | |
sources=["upload"], | |
) | |
with gr.Row(): | |
video_submit_btn = gr.Button( | |
value="Compute Depth", variant="primary" | |
) | |
video_reset_btn = gr.Button(value="Reset") | |
with gr.Column(): | |
video_output_video = gr.Video( | |
label="Output video depth (red-near, blue-far)", | |
interactive=False, | |
) | |
video_output_files = gr.Files( | |
label="Depth outputs", | |
elem_id="download", | |
interactive=False, | |
) | |
Examples( | |
fn=process_pipe_video, | |
examples=[ | |
os.path.join("files", "video", name) | |
for name in [ | |
"cab.mp4", | |
"elephant.mp4", | |
"obama.mp4", | |
] | |
], | |
inputs=[video_input], | |
outputs=[video_output_video, video_output_files], | |
cache_examples=True, | |
directory_name="examples_video", | |
) | |
with gr.Tab("Bas-relief (3D)"): | |
gr.Markdown( | |
""" | |
<p align="justify"> | |
This part of the demo uses Marigold-LCM to create a bas-relief model. | |
The models are watertight, with correct normals, and exported in the STL format, which makes | |
them <b>3D-printable</b>. | |
</p> | |
""", | |
) | |
with gr.Row(): | |
with gr.Column(): | |
bas_input = gr.Image( | |
label="Input Image", | |
type="filepath", | |
) | |
with gr.Row(): | |
bas_submit_btn = gr.Button( | |
value="Create 3D", variant="primary" | |
) | |
bas_reset_btn = gr.Button(value="Reset") | |
with gr.Accordion("3D printing demo: Main options", open=True): | |
bas_plane_near = gr.Slider( | |
label="Relative position of the near plane (between 0 and 1)", | |
minimum=0.0, | |
maximum=1.0, | |
step=0.001, | |
value=default_bas_plane_near, | |
) | |
bas_plane_far = gr.Slider( | |
label="Relative position of the far plane (between near and 1)", | |
minimum=0.0, | |
maximum=1.0, | |
step=0.001, | |
value=default_bas_plane_far, | |
) | |
bas_embossing = gr.Slider( | |
label="Embossing level", | |
minimum=0, | |
maximum=100, | |
step=1, | |
value=default_bas_embossing, | |
) | |
with gr.Accordion( | |
"3D printing demo: Advanced options", open=False | |
): | |
bas_num_inference_steps = gr.Slider( | |
label="Number of denoising steps", | |
minimum=1, | |
maximum=4, | |
step=1, | |
value=default_bas_num_inference_steps, | |
) | |
bas_ensemble_size = gr.Slider( | |
label="Ensemble size", | |
minimum=1, | |
maximum=10, | |
step=1, | |
value=default_bas_ensemble_size, | |
) | |
bas_processing_resolution = gr.Radio( | |
[ | |
("Native", 0), | |
("Recommended", 768), | |
], | |
label="Processing resolution", | |
value=default_bas_processing_resolution, | |
) | |
bas_size_longest_px = gr.Slider( | |
label="Size (px) of the longest side", | |
minimum=256, | |
maximum=1024, | |
step=256, | |
value=default_bas_size_longest_px, | |
) | |
bas_size_longest_cm = gr.Slider( | |
label="Size (cm) of the longest side", | |
minimum=1, | |
maximum=100, | |
step=1, | |
value=default_bas_size_longest_cm, | |
) | |
bas_filter_size = gr.Slider( | |
label="Size (px) of the smoothing filter", | |
minimum=1, | |
maximum=5, | |
step=2, | |
value=default_bas_filter_size, | |
) | |
bas_frame_thickness = gr.Slider( | |
label="Frame thickness", | |
minimum=0, | |
maximum=100, | |
step=1, | |
value=default_bas_frame_thickness, | |
) | |
bas_frame_near = gr.Slider( | |
label="Frame's near plane offset", | |
minimum=-100, | |
maximum=100, | |
step=1, | |
value=default_bas_frame_near, | |
) | |
bas_frame_far = gr.Slider( | |
label="Frame's far plane offset", | |
minimum=1, | |
maximum=10, | |
step=1, | |
value=default_bas_frame_far, | |
) | |
with gr.Column(): | |
bas_output_viewer = gr.Model3D( | |
camera_position=(75.0, 90.0, 1.25), | |
elem_classes="viewport", | |
label="3D preview (low-res, relief highlight)", | |
interactive=False, | |
) | |
bas_output_files = gr.Files( | |
label="3D model outputs (high-res)", | |
elem_id="download", | |
interactive=False, | |
) | |
Examples( | |
fn=process_pipe_bas, | |
examples=[ | |
[ | |
"files/basrelief/coin.jpg", # input | |
0.0, # plane_near | |
0.66, # plane_far | |
15, # embossing | |
4, # num_inference_steps | |
4, # ensemble_size | |
768, # processing_resolution | |
512, # size_longest_px | |
10, # size_longest_cm | |
3, # filter_size | |
5, # frame_thickness | |
0, # frame_near | |
1, # frame_far | |
], | |
[ | |
"files/basrelief/einstein.jpg", # input | |
0.0, # plane_near | |
0.5, # plane_far | |
50, # embossing | |
2, # num_inference_steps | |
1, # ensemble_size | |
768, # processing_resolution | |
512, # size_longest_px | |
10, # size_longest_cm | |
3, # filter_size | |
5, # frame_thickness | |
-25, # frame_near | |
1, # frame_far | |
], | |
[ | |
"files/basrelief/food.jpeg", # input | |
0.0, # plane_near | |
1.0, # plane_far | |
20, # embossing | |
2, # num_inference_steps | |
4, # ensemble_size | |
768, # processing_resolution | |
512, # size_longest_px | |
10, # size_longest_cm | |
3, # filter_size | |
5, # frame_thickness | |
-5, # frame_near | |
1, # frame_far | |
], | |
], | |
inputs=[ | |
bas_input, | |
bas_plane_near, | |
bas_plane_far, | |
bas_embossing, | |
bas_num_inference_steps, | |
bas_ensemble_size, | |
bas_processing_resolution, | |
bas_size_longest_px, | |
bas_size_longest_cm, | |
bas_filter_size, | |
bas_frame_thickness, | |
bas_frame_near, | |
bas_frame_far, | |
], | |
outputs=[bas_output_viewer, bas_output_files], | |
cache_examples=True, | |
directory_name="examples_bas", | |
) | |
### Image tab | |
if hf_writer is not None: | |
image_submit_btn.click( | |
fn=process_image_check, | |
inputs=image_input, | |
outputs=None, | |
preprocess=False, | |
queue=False, | |
).success( | |
get_share_conditioned_on_login, | |
None, | |
[share_instructions, share_login_btn], | |
queue=False, | |
).then( | |
lambda: ( | |
gr.Button(value="Share", interactive=True), | |
gr.Accordion(visible=True), | |
False, | |
False, | |
"", | |
), | |
None, | |
[ | |
share_share_btn, | |
share_box, | |
share_transfer_of_rights, | |
share_content_is_legal, | |
share_reason, | |
], | |
queue=False, | |
).then( | |
fn=process_pipe_image, | |
inputs=[ | |
image_input, | |
image_num_inference_steps, | |
image_ensemble_size, | |
image_processing_resolution, | |
], | |
outputs=[image_output_slider, image_output_files], | |
concurrency_limit=1, | |
) | |
else: | |
image_submit_btn.click( | |
fn=process_image_check, | |
inputs=image_input, | |
outputs=None, | |
preprocess=False, | |
queue=False, | |
).success( | |
fn=process_pipe_image, | |
inputs=[ | |
image_input, | |
image_num_inference_steps, | |
image_ensemble_size, | |
image_processing_resolution, | |
], | |
outputs=[image_output_slider, image_output_files], | |
concurrency_limit=1, | |
) | |
image_reset_btn.click( | |
fn=lambda: ( | |
None, | |
None, | |
None, | |
default_image_ensemble_size, | |
default_image_num_inference_steps, | |
default_image_processing_resolution, | |
), | |
inputs=[], | |
outputs=[ | |
image_input, | |
image_output_slider, | |
image_output_files, | |
image_ensemble_size, | |
image_num_inference_steps, | |
image_processing_resolution, | |
], | |
queue=False, | |
) | |
if hf_writer is not None: | |
image_reset_btn.click( | |
fn=lambda: ( | |
gr.Button(value="Share", interactive=True), | |
gr.Accordion(visible=default_share_always_show_accordion), | |
), | |
inputs=[], | |
outputs=[ | |
share_share_btn, | |
share_box, | |
], | |
queue=False, | |
) | |
### Share functionality | |
if hf_writer is not None: | |
share_components = [ | |
image_input, | |
image_num_inference_steps, | |
image_ensemble_size, | |
image_processing_resolution, | |
image_output_slider, | |
share_content_is_legal, | |
share_transfer_of_rights, | |
share_reason, | |
] | |
hf_writer.setup(share_components, "shared_data") | |
share_callback = FlagMethod(hf_writer, "Share", "", visual_feedback=True) | |
def share_precheck( | |
hf_content_is_legal, | |
image_output_slider, | |
profile: gr.OAuthProfile | None, | |
): | |
if profile is None: | |
raise gr.Error( | |
"Log into the Space with your Hugging Face account first." | |
) | |
if image_output_slider is None or image_output_slider[0] is None: | |
raise gr.Error("No output detected; process the image first.") | |
if not hf_content_is_legal: | |
raise gr.Error( | |
"You must consent that the uploaded content is legal." | |
) | |
return gr.Button(value="Sharing in progress", interactive=False) | |
share_share_btn.click( | |
share_precheck, | |
[share_content_is_legal, image_output_slider], | |
share_share_btn, | |
preprocess=False, | |
queue=False, | |
).success( | |
share_callback, | |
inputs=share_components, | |
outputs=share_share_btn, | |
preprocess=False, | |
queue=False, | |
) | |
### Video tab | |
video_submit_btn.click( | |
fn=process_pipe_video, | |
inputs=[video_input], | |
outputs=[video_output_video, video_output_files], | |
concurrency_limit=1, | |
) | |
video_reset_btn.click( | |
fn=lambda: (None, None, None), | |
inputs=[], | |
outputs=[video_input, video_output_video, video_output_files], | |
concurrency_limit=1, | |
) | |
### Bas-relief tab | |
bas_submit_btn.click( | |
fn=process_pipe_bas, | |
inputs=[ | |
bas_input, | |
bas_plane_near, | |
bas_plane_far, | |
bas_embossing, | |
bas_num_inference_steps, | |
bas_ensemble_size, | |
bas_processing_resolution, | |
bas_size_longest_px, | |
bas_size_longest_cm, | |
bas_filter_size, | |
bas_frame_thickness, | |
bas_frame_near, | |
bas_frame_far, | |
], | |
outputs=[bas_output_viewer, bas_output_files], | |
concurrency_limit=1, | |
) | |
bas_reset_btn.click( | |
fn=lambda: ( | |
gr.Button(interactive=True), | |
None, | |
None, | |
None, | |
default_bas_plane_near, | |
default_bas_plane_far, | |
default_bas_embossing, | |
default_bas_num_inference_steps, | |
default_bas_ensemble_size, | |
default_bas_processing_resolution, | |
default_bas_size_longest_px, | |
default_bas_size_longest_cm, | |
default_bas_filter_size, | |
default_bas_frame_thickness, | |
default_bas_frame_near, | |
default_bas_frame_far, | |
), | |
inputs=[], | |
outputs=[ | |
bas_submit_btn, | |
bas_input, | |
bas_output_viewer, | |
bas_output_files, | |
bas_plane_near, | |
bas_plane_far, | |
bas_embossing, | |
bas_num_inference_steps, | |
bas_ensemble_size, | |
bas_processing_resolution, | |
bas_size_longest_px, | |
bas_size_longest_cm, | |
bas_filter_size, | |
bas_frame_thickness, | |
bas_frame_near, | |
bas_frame_far, | |
], | |
concurrency_limit=1, | |
) | |
### Server launch | |
demo.queue( | |
api_open=False, | |
).launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
) | |
def main(): | |
CHECKPOINT = "prs-eth/marigold-depth-lcm-v1-0" | |
CROWD_DATA = "crowddata-marigold-depth-lcm-v1-0-space-v1-0" | |
os.system("pip freeze") | |
if "HF_TOKEN_LOGIN" in os.environ: | |
login(token=os.environ["HF_TOKEN_LOGIN"]) | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
pipe = MarigoldDepthPipeline.from_pretrained( | |
CHECKPOINT, variant="fp16", torch_dtype=torch.float16 | |
).to(device) | |
pipe.vae_tiny = diffusers.AutoencoderTiny.from_pretrained( | |
"madebyollin/taesd", torch_dtype=torch.float16 | |
).to(device) | |
pipe.set_progress_bar_config(disable=True) | |
try: | |
import xformers | |
pipe.enable_xformers_memory_efficient_attention() | |
except: | |
pass # run without xformers | |
hf_writer = None | |
if "HF_TOKEN_LOGIN_WRITE_CROWD" in os.environ: | |
hf_writer = HuggingFaceDatasetSaver( | |
os.getenv("HF_TOKEN_LOGIN_WRITE_CROWD"), | |
CROWD_DATA, | |
private=True, | |
info_filename="dataset_info.json", | |
separate_dirs=True, | |
) | |
run_demo_server(pipe, hf_writer) | |
if __name__ == "__main__": | |
main() | |