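"""Gradio Space: DepthPro depth estimation with 3D visualization.

Predicts a metric depth map from a single image using Apple's DepthPro model,
then back-projects the depth through a pinhole camera model into a textured
3D mesh that can be viewed in the browser and downloaded, with optional mesh
post-processing (face culling, simplification, smoothing, thin-feature
removal).
"""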
# HF Spaces ZeroGPU helper; kept as the first import, ahead of torch.
import spaces

import gradio as gr
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import tempfile
import os
import trimesh
import time
from datetime import datetime
import pytz

import torch
import src.depth_pro as depth_pro
import timm
import cv2

print(f"Timm version: {timm.__version__}")

# Fetch the pretrained DepthPro checkpoints used by create_model_and_transforms().
subprocess.run(["bash", "get_pretrained_models.sh"])


# @spaces.GPU allocates a ZeroGPU device for the duration of each call.
@spaces.GPU(duration=30)
def load_model_and_predict(image_path):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Build the DepthPro network and its preprocessing transform, then move
    # the model to the GPU (if available) for inference.
    model, transform = depth_pro.create_model_and_transforms()
    model = model.to(device)
    model.eval()

    # load_rgb returns the image plus metadata; the EXIF-derived focal length
    # in pixels (which may be None) is the last element.
    result = depth_pro.load_rgb(image_path)
    if len(result) < 2:
        raise ValueError(f"Unexpected result from load_rgb: {result}")

    image = result[0]
    f_px = result[-1]
    print(f"Extracted focal length: {f_px}")

    image = transform(image).to(device)

    with torch.no_grad():
        prediction = model.infer(image, f_px=f_px)

    # Metric depth plus the focal length the model actually used (estimated
    # by the model when no EXIF value was supplied).
    depth = prediction["depth"].cpu().numpy()
    focallength_px = prediction["focallength_px"]

    return depth, focallength_px
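
# Example (outside Gradio), assuming a local "photo.jpg":
#   depth, f_px = load_model_and_predict("photo.jpg")
# where depth is an (H, W) NumPy array of metric depth in meters.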


def resize_image(image_path, max_size=1024):
    """
    Resize the input image so its largest dimension does not exceed max_size,
    preserving the aspect ratio. Images already within the limit are left at
    their original resolution. The result is saved as a temporary PNG file.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the resized temporary image file.
    """
    with Image.open(image_path) as img:
        # Cap the ratio at 1.0 so small images are never upscaled.
        ratio = min(1.0, max_size / max(img.size))
        new_size = tuple(int(x * ratio) for x in img.size)

        # LANCZOS resampling gives high-quality downscaling.
        img = img.resize(new_size, Image.LANCZOS)

        # delete=False so the file survives this function; the caller is
        # responsible for removing it when done.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name


@spaces.GPU(duration=30)
def generate_3d_model(depth, image_path, focallength_px, simplification_factor=1.0,
                      smoothing_iterations=0, thin_threshold=0,
                      enable_face_culling=False, culling_threshold=0.1):
    """
    Generate a textured 3D mesh from the depth map and the original image.
    """
    try:
        print("Starting 3D model generation")

        # Force RGB so the per-vertex color reshape below always sees exactly
        # three channels, even for RGBA or grayscale inputs.
        image = Image.open(image_path).convert("RGB")
        image_array = np.array(image)

        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Depth map and image must match in size so each pixel maps to
        # exactly one vertex.
        if depth.shape != image_array.shape[:2]:
            depth = cv2.resize(depth, (image_array.shape[1], image_array.shape[0]),
                               interpolation=cv2.INTER_LINEAR)

        height, width = depth.shape

        print(f"3D model generation - Depth shape: {depth.shape}")
        print(f"3D model generation - Image shape: {image_array.shape}")

        # Pinhole intrinsics: square pixels (fx == fy) and the principal
        # point at the image center are assumed here.
        fx = fy = float(focallength_px)
        cx, cy = width / 2, height / 2

        # Pixel coordinate grid.
        u = np.arange(0, width)
        v = np.arange(0, height)
        uu, vv = np.meshgrid(u, v)
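
        # Back-project every pixel through the pinhole camera model. The
        # forward projection maps a 3D point (X, Y, Z) to pixel (u, v) via
        #     u = fx * X / Z + cx,    v = fy * Y / Z + cy,
        # so given the metric depth Z at each pixel the inverse is
        #     X = (u - cx) * Z / fx,  Y = (v - cy) * Z / fy.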
        Z = depth.flatten()
        X = ((uu.flatten() - cx) * Z) / fx
        Y = ((vv.flatten() - cy) * Z) / fy

        # One vertex per pixel, colored with that pixel's RGB value
        # normalized to [0, 1].
        vertices = np.vstack((X, Y, Z)).T
        colors = image_array.reshape(-1, 3) / 255.0

        print("Generating faces")

        # Connect each 2x2 block of neighboring pixels into two triangles,
        # splitting the quad along its idx+1 -> idx+width diagonal:
        #
        #   idx --------- idx+1
        #    |         /    |
        #    |       /      |
        #   idx+width --- idx+width+1
        #
        faces = []
        for i in range(height - 1):
            for j in range(width - 1):
                idx = i * width + j
                faces.append([idx, idx + width, idx + 1])
                faces.append([idx + 1, idx + width, idx + width + 1])
        faces = np.array(faces)
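
        # A vectorized sketch of the same triangulation (an alternative, not
        # used here) that avoids the Python loops via NumPy broadcasting:
        #
        #   idx = (np.arange(height - 1)[:, None] * width
        #          + np.arange(width - 1)).ravel()
        #   tri_a = np.stack([idx, idx + width, idx + 1], axis=1)
        #   tri_b = np.stack([idx + 1, idx + width, idx + width + 1], axis=1)
        #   faces = np.concatenate([tri_a, tri_b], axis=0)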

        print("Creating mesh")
        # process=False skips trimesh's automatic merging/validation, which
        # would reorder vertices and break the pixel-to-vertex mapping.
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces,
                               vertex_colors=colors, process=False)

        if enable_face_culling:
            print(f"Culling faces with normal dot product < {culling_threshold}")
            mesh = cull_faces_by_normal(mesh, min_dot_product_threshold=culling_threshold)
            print(f"After face culling - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

        if simplification_factor < 1.0 or smoothing_iterations > 0 or thin_threshold > 0:
            print(f"Original mesh - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

            if simplification_factor < 1.0:
                print("Simplifying mesh")
                target_faces = int(len(mesh.faces) * simplification_factor)
                mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
                print(f"After simplification - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

            if smoothing_iterations > 0:
                print("Smoothing mesh")
                for _ in range(smoothing_iterations):
                    mesh = mesh.smoothed()
                print(f"After smoothing - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

            if thin_threshold > 0:
                print("Removing thin features")
                mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
                print(f"After removing thin features - vertices: {len(mesh.vertices)}, faces: {len(mesh.faces)}")

        # Timestamped filenames keep repeated or concurrent requests from
        # overwriting each other's exports.
        timestamp = int(time.time())
        view_model_path = f'view_model_{timestamp}.obj'
        download_model_path = f'download_model_{timestamp}.obj'

        print("Exporting to view")
        mesh.export(view_model_path, include_texture=True)
        print("Exporting to download")
        mesh.export(download_model_path, include_texture=True)

        texture_path = f'texture_{timestamp}.png'
        image.save(texture_path)

        print("Export completed")
        return view_model_path, download_model_path, texture_path
    except Exception as e:
        print(f"Error in generate_3d_model: {str(e)}")
        raise


def cull_faces_by_normal(mesh, min_dot_product_threshold=0.1):
    """
    Removes faces from the mesh that are nearly vertical, which often cause
    'smearing' artifacts along depth discontinuities.

    Args:
        mesh (trimesh.Trimesh): The input mesh.
        min_dot_product_threshold (float): Faces whose normal has a Z-component
            with absolute value below this threshold are removed. A value of
            0.1 removes faces whose normals are tilted more than about
            84 degrees away from the view axis, i.e. faces within about
            6 degrees of being perfectly vertical.

    Returns:
        trimesh.Trimesh: The mesh with offending faces removed.
    """
    face_normals = mesh.face_normals
    view_vector = np.array([0, 0, 1])

    dot_products = np.abs(np.dot(face_normals, view_vector))
    keep_mask = dot_products > min_dot_product_threshold

    # Drop the filtered faces, then prune any vertices that no longer belong
    # to a face.
    mesh.update_faces(keep_mask)
    mesh.remove_unreferenced_vertices()

    return mesh


def remove_thin_features(mesh, thickness_threshold=0.01):
    """
    Remove thin features from the mesh by collapsing edges shorter than
    thickness_threshold, then dropping any degenerate faces left behind.
    """
    edges = mesh.edges_unique
    edge_points = mesh.vertices[edges]
    edge_lengths = np.linalg.norm(edge_points[:, 0] - edge_points[:, 1], axis=1)

    # Edges shorter than the threshold are treated as "thin" geometry.
    short_edges = edges[edge_lengths < thickness_threshold]

    for edge in short_edges:
        try:
            mesh.collapse_edge(edge)
        except Exception:
            # Edge collapse can fail (or be unavailable in some trimesh
            # versions); skip the edge rather than aborting the cleanup.
            pass

    mesh.remove_degenerate_faces()

    return mesh


@spaces.GPU
def regenerate_3d_model(depth_csv, image_path, focallength_px,
                        simplification_factor, smoothing_iterations, thin_threshold):
    # Note: the UI below wires the regenerate button directly to
    # create_3d_model, so this helper is currently unused.
    depth = np.loadtxt(depth_csv, delimiter=',')

    view_model_path, download_model_path, texture_path = generate_3d_model(
        depth, image_path, focallength_px,
        simplification_factor, smoothing_iterations, thin_threshold
    )
    print("regenerated!")
    return view_model_path, download_model_path, texture_path


@spaces.GPU(duration=30)
def predict_depth(input_image):
    temp_file = None
    try:
        print(f"Input image type: {type(input_image)}")
        print(f"Input image path: {input_image}")

        temp_file = resize_image(input_image)
        print(f"Resized image path: {temp_file}")

        depth, focallength_px = load_model_and_predict(temp_file)
        print(f"Raw depth type: {type(depth)}, focallength_px type: {type(focallength_px)}")

        # Drop any singleton batch/channel dimensions so the depth map is 2D.
        if depth.ndim != 2:
            depth = depth.squeeze()

        print(f"Depth map shape: {depth.shape}")

        # Render the depth map as a color-mapped PNG for display.
        plt.figure(figsize=(10, 10))
        plt.imshow(depth, cmap='gist_rainbow')
        plt.colorbar(label='Depth [m]')
        plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
        plt.axis('off')

        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()

        # Persist the raw depth values so the 3D step can reload them exactly.
        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path,
                   depth.cpu().numpy() if isinstance(depth, torch.Tensor) else depth,
                   delimiter=',')
        print(f"Saved raw depth map to {raw_depth_path}")

        focallength_px = float(focallength_px)
        print(f"Converted focallength_px to float: {focallength_px}")

        print("Depth map created!")
        print(f"Returning - output_path: {output_path}, focallength_px: {focallength_px}, "
              f"raw_depth_path: {raw_depth_path}, temp_file: {temp_file}")
        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, focallength_px
    except Exception as e:
        import traceback
        error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)
        return None, error_message, None, None
    finally:
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
            print(f"Removed temporary file: {temp_file}")


@spaces.GPU(duration=30)
def create_3d_model(depth_csv, image_path, focallength_px, simplification_factor,
                    smoothing_iterations, thin_threshold,
                    enable_face_culling=False, culling_threshold=0.1):
    try:
        # Reload the raw depth values saved by predict_depth.
        depth = np.loadtxt(depth_csv, delimiter=',')

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")

        print(f"Loading image from: {image_path}")

        view_model_path, download_model_path, texture_path = generate_3d_model(
            depth, image_path, focallength_px,
            simplification_factor, smoothing_iterations, thin_threshold,
            enable_face_culling, culling_threshold
        )
        print("3D model generated!")
        return view_model_path, download_model_path, texture_path, "3D model created successfully!"
    except Exception as e:
        error_message = f"An error occurred during 3D model creation: {str(e)}"
        print(error_message)
        return None, None, None, error_message


def get_last_commit_timestamp():
    try:
        # Ask git for the date of the most recent commit in ISO format.
        timestamp = subprocess.check_output(
            ['git', 'log', '-1', '--format=%cd', '--date=iso']
        ).decode('utf-8').strip()
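
        # '--date=iso' yields e.g. '2024-05-01 12:34:56 +0000', which is what
        # the strptime format below expects.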
        dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S %z")

        # Normalize to UTC for a consistent display value.
        dt_utc = dt.astimezone(pytz.UTC)

        return dt_utc.strftime("%Y-%m-%d %H:%M:%S UTC")
    except Exception as e:
        print(f"Error getting last commit timestamp: {str(e)}")
        return "Unknown"


# Computed once at startup and shown in the UI header.
last_updated = get_last_commit_timestamp()

with gr.Blocks() as iface:
    gr.Markdown("# DepthPro Demo with 3D Visualization by Alex Krause")
    gr.Markdown(
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        "Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n"
        "**Instructions:**\n"
        "1. Upload an image to generate the depth map.\n"
        "2. Click 'Generate 3D Model' to create the 3D visualization.\n"
        "3. Adjust parameters and click 'Regenerate 3D Model' to update the model.\n"
        "4. Download the raw depth data as a CSV file or the 3D model as an OBJ file if desired.\n\n"
        f"Last updated: {last_updated}"
    )
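
    # Layout: image input and rendered depth map side by side, followed by
    # the focal-length readout, raw-depth CSV download, and 3D model controls.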
    with gr.Row():
        input_image = gr.Image(type="filepath", label="Input Image")
        depth_map = gr.Image(type="filepath", label="Depth Map")

    focal_length = gr.Textbox(label="Focal Length")
    raw_depth_csv = gr.File(label="Download Raw Depth Map (CSV)")

    generate_3d_button = gr.Button("Generate 3D Model")

    with gr.Row():
        view_3d_model = gr.Model3D(label="View 3D Model")
        download_3d_model = gr.File(label="Download 3D Model (OBJ)")
        download_texture = gr.File(label="Download Texture (PNG)")

    with gr.Row():
        simplification_factor = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.1,
                                          label="Simplification Factor (1.0 = No simplification)")
        smoothing_iterations = gr.Slider(minimum=0, maximum=5, value=0, step=1,
                                         label="Smoothing Iterations (0 = No smoothing)")
        thin_threshold = gr.Slider(minimum=0, maximum=0.1, value=0, step=0.001,
                                   label="Thin Feature Threshold (0 = No thin feature removal)")

    with gr.Row():
        enable_face_culling = gr.Checkbox(label="Enable Face Culling", value=True)
        culling_threshold = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.01,
                                      label="Face Culling Threshold")

    regenerate_button = gr.Button("Regenerate 3D Model")
    model_status = gr.Textbox(label="3D Model Status")
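
    # gr.State carries the numeric focal length between events without
    # displaying it; the visible Textbox above is informational only.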
    hidden_focal_length = gr.State()

    input_image.change(
        predict_depth,
        inputs=[input_image],
        outputs=[depth_map, focal_length, raw_depth_csv, hidden_focal_length]
    )

    # Both buttons run the same handler; 'Regenerate' simply reuses the saved
    # depth CSV with the current slider settings.
    generate_3d_button.click(
        create_3d_model,
        inputs=[raw_depth_csv, input_image, hidden_focal_length, simplification_factor,
                smoothing_iterations, thin_threshold, enable_face_culling, culling_threshold],
        outputs=[view_3d_model, download_3d_model, download_texture, model_status]
    )

    regenerate_button.click(
        create_3d_model,
        inputs=[raw_depth_csv, input_image, hidden_focal_length, simplification_factor,
                smoothing_iterations, thin_threshold, enable_face_culling, culling_threshold],
        outputs=[view_3d_model, download_3d_model, download_texture, model_status]
    )
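
# share=True requests a public gradio.live link when running locally; on
# Hugging Face Spaces the flag has no effect and the Space's own URL is used.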
iface.launch(share=True)