File size: 14,320 Bytes
47669a5
cf8c487
 
 
 
dc9d69c
ee04d83
 
f95c546
40c89eb
47669a5
6648275
47669a5
 
 
 
39f1439
 
40c89eb
 
024a2b8
dc9d69c
cf8c487
3f2f727
39f1439
4322221
39f1439
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
764b436
39f1439
cf8c487
501d06f
f95c546
 
 
 
 
 
 
 
 
 
 
501d06f
f95c546
501d06f
 
 
f95c546
501d06f
 
f95c546
ee04d83
 
 
501d06f
423abd1
9194204
f95c546
 
 
7ce37c4
 
 
6648275
 
7ce37c4
 
 
 
 
 
6648275
 
7ce37c4
 
 
 
6648275
7ce37c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6648275
7ce37c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f49612a
7ce37c4
9194204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ce37c4
 
 
 
 
6648275
7ce37c4
6648275
7ce37c4
6648275
 
 
 
 
 
7ce37c4
6648275
7ce37c4
 
 
f95c546
2b69b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
423abd1
f9c3dad
 
 
 
 
6648275
f9c3dad
 
 
83e6e59
6648275
f9c3dad
3f2f727
cf8c487
ee04d83
404967f
29a026e
 
 
ee04d83
29a026e
501d06f
39f1439
4322221
404967f
 
 
 
b3b839e
f95c546
404967f
b3b839e
182cf21
b3b839e
39f1439
f95c546
404967f
 
 
 
182cf21
4322221
 
 
 
 
182cf21
3f2f727
4322221
9890878
404967f
29a026e
 
39f1439
9890878
ee04d83
 
 
4322221
cf8c487
423abd1
3f2f727
 
 
ced7c47
 
 
 
 
 
 
6648275
3f2f727
 
 
 
6648275
3f2f727
 
 
6648275
3f2f727
c6f3d95
 
6648275
c6f3d95
6648275
 
 
 
 
 
 
 
 
b3b839e
6648275
 
c6f3d95
d893f72
4bfe855
 
f9c3dad
 
 
f95c546
4bfe855
f95c546
3f2f727
 
 
 
4bfe855
f9c3dad
 
 
 
 
 
 
 
 
3f2f727
 
f9c3dad
 
 
6648275
f9c3dad
 
9194204
 
 
f9c3dad
 
3f2f727
f9c3dad
 
 
 
 
9890878
f9c3dad
9890878
3f2f727
 
 
 
9890878
6648275
f9c3dad
 
 
3f2f727
9890878
6648275
f9c3dad
cf8c487
f95c546
f9c3dad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import spaces
import gradio as gr
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import tempfile
import os
import trimesh
import time
from datetime import datetime
import pytz

# Import potentially CUDA-initializing modules after 'spaces'
import torch
import src.depth_pro as depth_pro
import timm
import cv2

# Log the installed timm version at startup.
print(f"Timm version: {timm.__version__}")

# Run get_pretrained_models.sh at import time. The return code is not checked,
# so a failing script does not stop the app from starting.
# NOTE(review): presumably this fetches the DepthPro checkpoint - confirm.
subprocess.run(["bash", "get_pretrained_models.sh"])

@spaces.GPU(duration=30)
def load_model_and_predict(image_path):
    """
    Build the DepthPro model, run it on one image and return the prediction.

    The model is created on every call and moved to the first CUDA device when
    one is available, otherwise to the CPU.

    Args:
        image_path (str): Path of the image handed to depth_pro.load_rgb.

    Returns:
        tuple: (depth, focallength_px) where depth is the predicted depth as a
        NumPy array and focallength_px is the value reported by the model.

    Raises:
        ValueError: If depth_pro.load_rgb returns fewer than two items.
    """
    use_cuda = torch.cuda.is_available()
    target = torch.device("cuda:0" if use_cuda else "cpu")

    net, preprocess = depth_pro.create_model_and_transforms()
    net = net.to(target)
    net.eval()

    loaded = depth_pro.load_rgb(image_path)
    if len(loaded) < 2:
        raise ValueError(f"Unexpected result from load_rgb: {loaded}")

    # First item is the pixel data, last item is the focal length hint
    rgb, f_px = loaded[0], loaded[-1]
    print(f"Extracted focal length: {f_px}")

    model_input = preprocess(rgb).to(target)

    # Inference only: no gradients are needed
    with torch.no_grad():
        output = net.infer(model_input, f_px=f_px)

    return output["depth"].cpu().numpy(), output["focallength_px"]

def resize_image(image_path, max_size=1024):
    """
    Downscale an image so that its largest dimension does not exceed max_size.

    The aspect ratio is preserved. Images that already fit within max_size are
    not enlarged; they are only re-encoded. The result is always written to a new
    temporary PNG file that is not deleted automatically, so the caller is
    responsible for removing it.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Maximum size for the largest dimension. Defaults to 1024.

    Returns:
        str: Path to the resized temporary image file.
    """
    with Image.open(image_path) as img:
        # Cap the ratio at 1.0 so that small images are never upscaled
        ratio = min(1.0, max_size / max(img.size))
        # Keep at least one pixel per side for extreme aspect ratios
        new_size = tuple(max(1, int(x * ratio)) for x in img.size)

        if new_size != img.size:
            # Resize the image using LANCZOS filter for high-quality downsampling
            img = img.resize(new_size, Image.LANCZOS)

        # Save the (possibly resized) image to a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name

@spaces.GPU  # Increased duration to default 60 seconds
def generate_3d_model(depth, image_path, focallength_px, simplification_factor=1.0, smoothing_iterations=0, thin_threshold=0):
    """
    Generate a vertex-colored 3D mesh from a depth map and the original image.

    Every pixel is back-projected with a pinhole camera model whose principal
    point is the image center, and each 2x2 pixel neighborhood is connected with
    two triangles. The mesh is exported twice as OBJ (one file for viewing, one
    for download) and the image is saved next to them as a PNG.

    Args:
        depth: 2D depth map as a NumPy array or torch.Tensor. It is resized to
            the image resolution when the shapes differ.
        image_path (str): Path to the color image; it is converted to RGB.
        focallength_px: Focal length in pixels; must be convertible with float().
        simplification_factor (float): Fraction of faces to keep. Values below
            1.0 trigger quadric decimation.
        smoothing_iterations (int): Number of times mesh.smoothed() is applied.
        thin_threshold (float): Edge-length threshold passed to
            remove_thin_features; 0 disables that step.

    Returns:
        tuple: (view_model_path, download_model_path, texture_path).

    Raises:
        Exception: Any error is printed and re-raised unchanged.
    """
    try:
        print("Starting 3D model generation")
        # Load the image as RGB. RGBA, grayscale and palette images would
        # otherwise break the (-1, 3) reshape used for the vertex colors.
        # convert() returns an independent copy, so the file can be closed here.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        image_array = np.array(image)

        # Ensure depth is a NumPy array
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Resize depth to match image dimensions if necessary
        if depth.shape != image_array.shape[:2]:
            depth = cv2.resize(depth, (image_array.shape[1], image_array.shape[0]), interpolation=cv2.INTER_LINEAR)

        height, width = depth.shape

        print(f"3D model generation - Depth shape: {depth.shape}")
        print(f"3D model generation - Image shape: {image_array.shape}")

        # Compute camera intrinsic parameters
        fx = fy = float(focallength_px)  # Ensure focallength_px is a float
        cx, cy = width / 2, height / 2  # Principal point at the image center

        # Create a grid of (u, v) pixel coordinates
        u = np.arange(0, width)
        v = np.arange(0, height)
        uu, vv = np.meshgrid(u, v)

        # Convert pixel coordinates to 3D coordinates using the pinhole camera model
        Z = depth.flatten()
        X = ((uu.flatten() - cx) * Z) / fx
        Y = ((vv.flatten() - cy) * Z) / fy

        # Stack the coordinates to form vertices (X, Y, Z)
        vertices = np.vstack((X, Y, Z)).T

        # Normalize RGB colors to [0, 1] for vertex coloring
        colors = image_array.reshape(-1, 3) / 255.0

        print("Generating faces")
        # Index of the top-left vertex of every grid cell, in row-major order
        cell_rows = np.arange(height - 1).reshape(-1, 1)
        cell_cols = np.arange(width - 1)
        idx = (cell_rows * width + cell_cols).ravel()
        # Two triangles per cell; stacking on axis 1 keeps them interleaved
        # (triangle 1, triangle 2, next cell, ...) like a per-cell loop would.
        first_triangles = np.stack((idx, idx + width, idx + 1), axis=1)
        second_triangles = np.stack((idx + 1, idx + width, idx + width + 1), axis=1)
        faces = np.stack((first_triangles, second_triangles), axis=1).reshape(-1, 3)

        print("Creating mesh")
        # Create the mesh using Trimesh with vertex colors
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors, process=False)

        # Mesh cleaning and improvement steps (only if not using default values)
        if simplification_factor < 1.0 or smoothing_iterations > 0 or thin_threshold > 0:
            print("Original mesh - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

            if simplification_factor < 1.0:
                print("Simplifying mesh")
                target_faces = int(len(mesh.faces) * simplification_factor)
                mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
                print("After simplification - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

            if smoothing_iterations > 0:
                print("Smoothing mesh")
                # NOTE(review): confirm that Trimesh.smoothed() moves vertices
                # rather than only changing how the mesh is shaded.
                for _ in range(smoothing_iterations):
                    mesh = mesh.smoothed()
                print("After smoothing - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

            if thin_threshold > 0:
                print("Removing thin features")
                mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
                print("After removing thin features - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

        # Export the mesh to OBJ files named after the current Unix second
        timestamp = int(time.time())
        view_model_path = f'view_model_{timestamp}.obj'
        download_model_path = f'download_model_{timestamp}.obj'

        print("Exporting to view")
        mesh.export(view_model_path, include_texture=True)
        print("Exporting to download")
        mesh.export(download_model_path, include_texture=True)

        # Save the texture image
        texture_path = f'texture_{timestamp}.png'
        image.save(texture_path)

        print("Export completed")
        return view_model_path, download_model_path, texture_path
    except Exception as e:
        print(f"Error in generate_3d_model: {str(e)}")
        raise

def remove_thin_features(mesh, thickness_threshold=0.01):
    """
    Try to collapse every unique edge shorter than the threshold.

    Collapsing is best-effort: an edge whose collapse raises is skipped, and the
    number of skipped edges is printed so that a step that silently does nothing
    becomes visible. Degenerate faces are removed afterwards.

    Args:
        mesh: Mesh object providing edges_unique, vertices, collapse_edge() and
            remove_degenerate_faces().
        thickness_threshold (float): Edges strictly shorter than this length are
            collapsed. Defaults to 0.01.

    Returns:
        The same mesh object that was passed in.
    """
    # Calculate edge lengths
    edges = mesh.edges_unique
    edge_points = mesh.vertices[edges]
    edge_lengths = np.linalg.norm(edge_points[:, 0] - edge_points[:, 1], axis=1)

    # Identify short edges
    short_edges = edges[edge_lengths < thickness_threshold]

    # Collapse short edges, skipping the ones that cannot be collapsed.
    # NOTE(review): if the mesh class has no collapse_edge method, every edge is
    # skipped here (AttributeError) - confirm the mesh type actually offers it.
    skipped = 0
    for edge in short_edges:
        try:
            mesh.collapse_edge(edge)
        except Exception:
            # Deliberately best-effort, but no longer hides KeyboardInterrupt
            skipped += 1

    if skipped:
        print(f"Skipped {skipped} of {len(short_edges)} short-edge collapses")

    # Remove any newly created degenerate faces
    mesh.remove_degenerate_faces()

    return mesh

@spaces.GPU  # Increased duration to default 60 seconds
def regenerate_3d_model(depth_csv, image_path, focallength_px, simplification_factor, smoothing_iterations, thin_threshold):
    """
    Rebuild the 3D model from a saved depth CSV with new mesh parameters.

    Args:
        depth_csv: Comma-separated depth file readable by np.loadtxt.
        image_path: Path of the image forwarded to generate_3d_model.
        focallength_px: Focal length forwarded to generate_3d_model.
        simplification_factor: Forwarded to generate_3d_model.
        smoothing_iterations: Forwarded to generate_3d_model.
        thin_threshold: Forwarded to generate_3d_model.

    Returns:
        tuple: (view_model_path, download_model_path, texture_path).
    """
    # Read the depth values back from the CSV file
    depth_values = np.loadtxt(depth_csv, delimiter=',')

    # Forward the mesh options in the order generate_3d_model expects them
    mesh_options = (simplification_factor, smoothing_iterations, thin_threshold)
    view_path, download_path, texture_file = generate_3d_model(
        depth_values, image_path, focallength_px, *mesh_options
    )

    print("regenerated!")
    return view_path, download_path, texture_file

@spaces.GPU(duration=30)
def predict_depth(input_image):
    """
    Predict a depth map for an image and save it as a plot and as raw CSV.

    The image is first resized into a temporary file, which is always removed
    again before returning. The plot is written to "depth_map.png" and the raw
    values to "raw_depth_map.csv" in the working directory.

    Args:
        input_image: Path of the input image, or None when no image is set.

    Returns:
        tuple: (plot_path, focal_length_text, raw_depth_csv_path, focallength_px)
        on success. When no image is given or an error occurs, the tuple is
        (None, message, None, None).
    """
    # The input may be None (e.g. when the image component is cleared); answer
    # with a short message instead of running into an exception traceback.
    if not input_image:
        return None, "No image provided.", None, None

    temp_file = None
    try:
        print(f"Input image type: {type(input_image)}")
        print(f"Input image path: {input_image}")

        temp_file = resize_image(input_image)
        print(f"Resized image path: {temp_file}")

        depth, focallength_px = load_model_and_predict(temp_file)
        print(f"Raw depth type: {type(depth)}, focallength_px type: {type(focallength_px)}")

        # Work on a NumPy array from here on, for plotting as well as for saving
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        if depth.ndim != 2:
            depth = depth.squeeze()

        print(f"Depth map shape: {depth.shape}")

        output_path = "depth_map.png"
        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(depth, cmap='gist_rainbow')
            plt.colorbar(label='Depth [m]')
            plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
            plt.axis('off')
            plt.savefig(output_path)
        finally:
            # Close the figure even when plotting fails so it does not pile up
            plt.close(fig)

        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth, delimiter=',')
        print(f"Saved raw depth map to {raw_depth_path}")

        focallength_px = float(focallength_px)
        print(f"Converted focallength_px to float: {focallength_px}")

        print("Depth map created!")
        print(f"Returning - output_path: {output_path}, focallength_px: {focallength_px}, raw_depth_path: {raw_depth_path}, temp_file: {temp_file}")
        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path, focallength_px
    except Exception as e:
        import traceback
        error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)
        return None, error_message, None, None
    finally:
        # The resized copy is only needed for inference
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
            print(f"Removed temporary file: {temp_file}")

@spaces.GPU
def create_3d_model(depth_csv, image_path, focallength_px, simplification_factor, smoothing_iterations, thin_threshold):
    """
    Build the 3D model for the UI and report the outcome as a status message.

    Errors are not propagated: they are printed and returned as the status text.

    Args:
        depth_csv: Comma-separated depth file readable by np.loadtxt.
        image_path: Path of the source image; it must exist on disk.
        focallength_px: Focal length forwarded to generate_3d_model.
        simplification_factor: Forwarded to generate_3d_model.
        smoothing_iterations: Forwarded to generate_3d_model.
        thin_threshold: Forwarded to generate_3d_model.

    Returns:
        tuple: (view_model_path, download_model_path, texture_path, status) on
        success, or (None, None, None, error_message) on failure.
    """
    try:
        depth_values = np.loadtxt(depth_csv, delimiter=',')

        # Stop early when the source image is no longer on disk
        image_missing = not os.path.exists(image_path)
        if image_missing:
            raise FileNotFoundError(f"Image file not found: {image_path}")

        print(f"Loading image from: {image_path}")

        model_paths = generate_3d_model(
            depth_values,
            image_path,
            focallength_px,
            simplification_factor,
            smoothing_iterations,
            thin_threshold,
        )
        view_path, download_path, texture_file = model_paths
        print("3D model generated!")
        return view_path, download_path, texture_file, "3D model created successfully!"
    except Exception as exc:
        failure = f"An error occurred during 3D model creation: {exc}"
        print(failure)
        return None, None, None, failure

def get_last_commit_timestamp():
    """
    Return the date of the latest git commit, converted to UTC.

    Returns:
        str: The date formatted as "YYYY-MM-DD HH:MM:SS UTC", or "Unknown" when
        git cannot be queried or its output cannot be parsed.
    """
    # ISO output keeps the timezone offset, e.g. "2024-10-05 12:34:56 +0200"
    git_command = ['git', 'log', '-1', '--format=%cd', '--date=iso']
    try:
        raw_output = subprocess.check_output(git_command)
        commit_date = raw_output.decode('utf-8').strip()

        # Parse with the offset, then shift to UTC before formatting
        commit_dt = datetime.strptime(commit_date, "%Y-%m-%d %H:%M:%S %z")
        return commit_dt.astimezone(pytz.UTC).strftime("%Y-%m-%d %H:%M:%S UTC")
    except Exception as exc:
        print(f"Error getting last commit timestamp: {exc}")
        return "Unknown"
    
# Build the Gradio UI. The commit timestamp is resolved once and shown in the description.
last_updated = get_last_commit_timestamp()

with gr.Blocks() as iface:
    gr.Markdown("# DepthPro Demo with 3D Visualization")
    description = (
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        "Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n"
        "**Instructions:**\n"
        "1. Upload an image to generate the depth map.\n"
        "2. Click 'Generate 3D Model' to create the 3D visualization.\n"
        "3. Adjust parameters and click 'Regenerate 3D Model' to update the model.\n"
        "4. Download the raw depth data as a CSV file or the 3D model as an OBJ file if desired.\n\n"
        f"Last updated: {last_updated}"
    )
    gr.Markdown(description)

    # Input image next to the predicted depth map
    with gr.Row():
        input_image = gr.Image(type="filepath", label="Input Image")
        depth_map = gr.Image(type="filepath", label="Depth Map")

    focal_length = gr.Textbox(label="Focal Length")
    raw_depth_csv = gr.File(label="Download Raw Depth Map (CSV)")

    generate_3d_button = gr.Button("Generate 3D Model")

    # 3D viewer and the downloadable artifacts
    with gr.Row():
        view_3d_model = gr.Model3D(label="View 3D Model")
        download_3d_model = gr.File(label="Download 3D Model (OBJ)")
        download_texture = gr.File(label="Download Texture (PNG)")

    # Mesh post-processing controls; the default values disable every step
    with gr.Row():
        simplification_factor = gr.Slider(minimum=0.1, maximum=1.0, value=1.0, step=0.1, label="Simplification Factor (1.0 = No simplification)")
        smoothing_iterations = gr.Slider(minimum=0, maximum=5, value=0, step=1, label="Smoothing Iterations (0 = No smoothing)")
        thin_threshold = gr.Slider(minimum=0, maximum=0.1, value=0, step=0.001, label="Thin Feature Threshold (0 = No thin feature removal)")

    regenerate_button = gr.Button("Regenerate 3D Model")
    model_status = gr.Textbox(label="3D Model Status")

    # Hidden state that carries the numeric focal length between events
    hidden_focal_length = gr.State()

    # A new input image triggers depth prediction
    input_image.change(
        predict_depth,
        inputs=[input_image],
        outputs=[depth_map, focal_length, raw_depth_csv, hidden_focal_length]
    )

    # Both buttons run the same handler with the same inputs and outputs
    model_inputs = [raw_depth_csv, input_image, hidden_focal_length, simplification_factor, smoothing_iterations, thin_threshold]
    model_outputs = [view_3d_model, download_3d_model, download_texture, model_status]
    for model_button in (generate_3d_button, regenerate_button):
        model_button.click(
            create_3d_model,
            inputs=model_inputs,
            outputs=model_outputs
        )

# Launch the Gradio interface with sharing enabled
iface.launch(share=True)