import spaces
import gradio as gr
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import tempfile
import os
import trimesh
import time
from datetime import datetime
import pytz
# Import potentially CUDA-initializing modules after 'spaces'
import torch
import src.depth_pro as depth_pro
import timm
import cv2
print(f"Timm version: {timm.__version__}")
subprocess.run(["bash", "get_pretrained_models.sh"])
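# On ZeroGPU Spaces, a GPU is attached only while a @spaces.GPU-decorated
# function runs (default 60 s unless `duration` is given), so model loading
# and inference must happen inside the decorated functions below.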
@spaces.GPU(duration=30)
def load_model_and_predict(image_path):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
result = depth_pro.load_rgb(image_path)
if len(result) < 2:
raise ValueError(f"Unexpected result from load_rgb: {result}")
image = result[0]
f_px = result[-1]
print(f"Extracted focal length: {f_px}")
image = transform(image).to(device)
with torch.no_grad():
prediction = model.infer(image, f_px=f_px)
depth = prediction["depth"].cpu().numpy()
focallength_px = prediction["focallength_px"]
return depth, focallength_px
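# load_model_and_predict returns a metric depth map (in metres) plus the
# focal-length estimate in pixels; predict_depth below wraps it for the UI.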
def resize_image(image_path, max_size=1024):
"""
Resize the input image to ensure its largest dimension does not exceed max_size.
Maintains the aspect ratio and saves the resized image as a temporary PNG file.
Args:
image_path (str): Path to the input image.
max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.
Returns:
str: Path to the resized temporary image file.
"""
with Image.open(image_path) as img:
        # Compute the scaling ratio; cap at 1.0 so smaller images are never upscaled
        ratio = min(1.0, max_size / max(img.size))
        new_size = tuple(int(x * ratio) for x in img.size)
# Resize the image using LANCZOS filter for high-quality downsampling
img = img.resize(new_size, Image.LANCZOS)
# Save the resized image to a temporary file
with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
img.save(temp_file, format="PNG")
return temp_file.name
@spaces.GPU  # Default 60-second allocation; mesh generation can take longer than the 30 s used for prediction
def generate_3d_model(depth, image_path, focallength_px, simplification_factor=1.0, smoothing_iterations=0, thin_threshold=0):
"""
Generate a textured 3D mesh from the depth map and the original image.
"""
try:
print("Starting 3D model generation")
# Load the RGB image and convert to a NumPy array
        image = Image.open(image_path).convert("RGB")  # force 3 channels so the per-vertex color reshape below works for RGBA/grayscale inputs
image_array = np.array(image)
# Ensure depth is a NumPy array
if isinstance(depth, torch.Tensor):
depth = depth.cpu().numpy()
# Resize depth to match image dimensions if necessary
if depth.shape != image_array.shape[:2]:
depth = cv2.resize(depth, (image_array.shape[1], image_array.shape[0]), interpolation=cv2.INTER_LINEAR)
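            # Note: bilinear interpolation smears metric depth across object
            # boundaries; a nearest-neighbor resize would keep edges sharp at
            # the cost of blockier geometry.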
height, width = depth.shape
print(f"3D model generation - Depth shape: {depth.shape}")
print(f"3D model generation - Image shape: {image_array.shape}")
# Compute camera intrinsic parameters
fx = fy = float(focallength_px) # Ensure focallength_px is a float
cx, cy = width / 2, height / 2 # Principal point at the image center
# Create a grid of (u, v) pixel coordinates
u = np.arange(0, width)
v = np.arange(0, height)
uu, vv = np.meshgrid(u, v)
# Convert pixel coordinates to real-world 3D coordinates using the pinhole camera model
Z = depth.flatten()
X = ((uu.flatten() - cx) * Z) / fx
Y = ((vv.flatten() - cy) * Z) / fy
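        # Pinhole back-projection: pixel (u, v) at depth Z maps to
        #   X = (u - cx) * Z / fx,   Y = (v - cy) * Z / fy
        # so the lateral spread of points grows linearly with distance Z.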
# Stack the coordinates to form vertices (X, Y, Z)
vertices = np.vstack((X, Y, Z)).T
# Normalize RGB colors to [0, 1] for vertex coloring
colors = image_array.reshape(-1, 3) / 255.0
print("Generating faces")
# Generate faces by connecting adjacent vertices to form triangles
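        # Each quad of four neighbouring pixels becomes two triangles:
        #   idx ---------- idx+1
        #    |  \            |
        #    |    \          |
        #   idx+width -- idx+width+1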
faces = []
for i in range(height - 1):
for j in range(width - 1):
idx = i * width + j
# Triangle 1
faces.append([idx, idx + width, idx + 1])
# Triangle 2
faces.append([idx + 1, idx + width, idx + width + 1])
faces = np.array(faces)
print("Creating mesh")
# Create the mesh using Trimesh with vertex colors
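        # process=False stops trimesh from merging or reordering vertices, which
        # would break the pixel-grid indexing used for faces and vertex colors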
mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors, process=False)
# Mesh cleaning and improvement steps (only if not using default values)
if simplification_factor < 1.0 or smoothing_iterations > 0 or thin_threshold > 0:
print("Original mesh - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))
if simplification_factor < 1.0:
print("Simplifying mesh")
target_faces = int(len(mesh.faces) * simplification_factor)
mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
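                # Quadric decimation collapses edges while minimising squared
                # geometric error, leaving roughly target_faces faces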
print("After simplification - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))
            if smoothing_iterations > 0:
                print("Smoothing mesh")
                # mesh.smoothed() only changes shading, not geometry; Laplacian
                # filtering actually moves vertices toward their neighbours' mean
                trimesh.smoothing.filter_laplacian(mesh, iterations=smoothing_iterations)
print("After smoothing - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))
if thin_threshold > 0:
print("Removing thin features")
mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
print("After removing thin features - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))
# Export the mesh to OBJ files with unique filenames
timestamp = int(time.time())
view_model_path = f'view_model_{timestamp}.obj'
download_model_path = f'download_model_{timestamp}.obj'
print("Exporting to view")
mesh.export(view_model_path, include_texture=True)
print("Exporting to download")
mesh.export(download_model_path, include_texture=True)
# Save the texture image
texture_path = f'texture_{timestamp}.png'
image.save(texture_path)
print("Export completed")
return view_model_path, download_model_path, texture_path
except Exception as e:
print(f"Error in generate_3d_model: {str(e)}")
raise
def remove_thin_features(mesh, thickness_threshold=0.01):
    """
    Remove thin features by dropping faces that contain very short edges.
    Trimesh exposes no per-edge collapse, so face removal approximates it.
    """
    # Flag unique edges shorter than the threshold
    short = mesh.edges_unique_length < thickness_threshold
    if short.any():
        # faces_unique_edges maps each face to its three entries in edges_unique;
        # keep only faces whose edges are all at least thickness_threshold long
        face_mask = ~short[mesh.faces_unique_edges].any(axis=1)
        mesh.update_faces(face_mask)
    # Drop any degenerate faces that remain
    mesh.remove_degenerate_faces()
    return mesh
@spaces.GPU  # Default 60-second allocation
def regenerate_3d_model(depth_csv, image_path, focallength_px, simplification_factor, smoothing_iterations, thin_threshold):
# Load depth from CSV
depth = np.loadtxt(depth_csv, delimiter=',')
# Generate new 3D model with updated parameters
view_model_path, download_model_path, texture_path = generate_3d_model(
depth, image_path, focallength_px,
simplification_factor, smoothing_iterations, thin_threshold
)
print("regenerated!")
return view_model_path, download_model_path, texture_path
@spaces.GPU(duration=30)
def predict_depth(input_image):
temp_file = None
try:
print(f"Input image type: {type(input_image)}")
print(f"Input image path: {input_image}")
temp_file = resize_image(input_image)
print(f"Resized image path: {temp_file}")
depth, focallength_px = load_model_and_predict(temp_file)
print(f"Raw depth type: {type(depth)}, focallength_px type: {type(focallength_px)}")
if depth.ndim != 2:
depth = depth.squeeze()
print(f"Depth map shape: {depth.shape}")
plt.figure(figsize=(10, 10))
plt.imshow(depth, cmap='gist_rainbow')
plt.colorbar(label='Depth [m]')
plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
plt.axis('off')
output_path = "depth_map.png"
plt.savefig(output_path)
plt.close()
raw_depth_path = "raw_depth_map.csv"
np.savetxt(raw_depth_path, depth.cpu().numpy() if isinstance(depth, torch.Tensor) else depth, delimiter=',')
print(f"Saved raw depth map to {raw_depth_path}")
focallength_px = float(focallength_px)
print(f"Converted focallength_px to float: {focallength_px}")
print("Depth map created!")
print(f"Returning - output_path: {output_path}, focallength_px: {focallength_px}, raw_depth_path: {raw_depth_path}, temp_file: {temp_file}")
return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, focallength_px
except Exception as e:
import traceback
error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
print(error_message)
return None, error_message, None, None
finally:
if temp_file and os.path.exists(temp_file):
os.remove(temp_file)
print(f"Removed temporary file: {temp_file}")
@spaces.GPU
def create_3d_model(depth_csv, image_path, focallength_px, simplification_factor, smoothing_iterations, thin_threshold):
try:
depth = np.loadtxt(depth_csv, delimiter=',')
# Check if the image file exists
if not os.path.exists(image_path):
raise FileNotFoundError(f"Image file not found: {image_path}")
print(f"Loading image from: {image_path}")
view_model_path, download_model_path, texture_path = generate_3d_model(
depth, image_path, focallength_px,
simplification_factor, smoothing_iterations, thin_threshold
)
print("3D model generated!")
return view_model_path, download_model_path, texture_path, "3D model created successfully!"
except Exception as e:
error_message = f"An error occurred during 3D model creation: {str(e)}"
print(error_message)
return None, None, None, error_message
def get_last_commit_timestamp():
try:
# Get the timestamp in a format that includes timezone information
timestamp = subprocess.check_output(['git', 'log', '-1', '--format=%cd', '--date=iso']).decode('utf-8').strip()
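        # `git log -1 --format=%cd --date=iso` prints the committer date of the
        # latest commit, e.g. "2024-10-05 12:34:56 +0200"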
# Parse the timestamp, including the timezone
dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S %z")
# Convert to UTC
dt_utc = dt.astimezone(pytz.UTC)
# Format the date as desired
return dt_utc.strftime("%Y-%m-%d %H:%M:%S UTC")
except Exception as e:
print(f"Error getting last commit timestamp: {str(e)}")
return "Unknown"
# Create the Gradio interface with appropriate input and output components.
last_updated = get_last_commit_timestamp()
with gr.Blocks() as iface:
gr.Markdown("# DepthPro Demo with 3D Visualization")
gr.Markdown(
"An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
"Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n"
"**Instructions:**\n"
"1. Upload an image to generate the depth map.\n"
"2. Click 'Generate 3D Model' to create the 3D visualization.\n"
"3. Adjust parameters and click 'Regenerate 3D Model' to update the model.\n"
"4. Download the raw depth data as a CSV file or the 3D model as an OBJ file if desired.\n\n"
f"Last updated: {last_updated}"
)
with gr.Row():
input_image = gr.Image(type="filepath", label="Input Image")
depth_map = gr.Image(type="filepath", label="Depth Map")
focal_length = gr.Textbox(label="Focal Length")
raw_depth_csv = gr.File(label="Download Raw Depth Map (CSV)")
generate_3d_button = gr.Button("Generate 3D Model")
with gr.Row():
view_3d_model = gr.Model3D(label="View 3D Model")
download_3d_model = gr.File(label="Download 3D Model (OBJ)")
download_texture = gr.File(label="Download Texture (PNG)")
with gr.Row():
simplification_factor = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.1, label="Simplification Factor (1.0 = No simplification)")
smoothing_iterations = gr.Slider(minimum=0, maximum=5, value=0, step=1, label="Smoothing Iterations (0 = No smoothing)")
thin_threshold = gr.Slider(minimum=0, maximum=0.1, value=0, step=0.001, label="Thin Feature Threshold (0 = No thin feature removal)")
regenerate_button = gr.Button("Regenerate 3D Model")
model_status = gr.Textbox(label="3D Model Status")
# Hidden components to store intermediate results
hidden_focal_length = gr.State()
input_image.change(
predict_depth,
inputs=[input_image],
outputs=[depth_map, focal_length, raw_depth_csv, hidden_focal_length]
)
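    # predict_depth runs automatically on upload; the hidden gr.State keeps the
    # focal length so both buttons below can reuse it without re-running depth
    # estimation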
generate_3d_button.click(
create_3d_model,
inputs=[raw_depth_csv, input_image, hidden_focal_length, simplification_factor, smoothing_iterations, thin_threshold],
outputs=[view_3d_model, download_3d_model, download_texture, model_status]
)
regenerate_button.click(
create_3d_model,
inputs=[raw_depth_csv, input_image, hidden_focal_length, simplification_factor, smoothing_iterations, thin_threshold],
outputs=[view_3d_model, download_3d_model, download_texture, model_status]
)
# Launch the Gradio interface with sharing enabled
iface.launch(share=True) |