import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os

# Fetch the pretrained model checkpoints (fail fast if the download fails)
subprocess.run(["bash", "get_pretrained_models.sh"], check=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model and preprocessing transform
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
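# Note: `transform` converts the loaded RGB image into the normalized tensor
# the network expects; keep it paired with the model it was created with.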

def resize_image(image_path, max_size=1024):
    with Image.open(image_path) as img:
        # Calculate the new size while maintaining the aspect ratio;
        # never upscale images that are already within the limit
        ratio = min(1.0, max_size / max(img.size))
        new_size = tuple(int(x * ratio) for x in img.size)
        
        # Resize the image
        img = img.resize(new_size, Image.LANCZOS)
        
        # Create a temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name

@spaces.GPU(duration=20)
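# On Hugging Face Spaces, the decorator above requests ZeroGPU hardware for
# up to `duration` seconds per call to the decorated function.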
def predict_depth(input_image):
    temp_file = None
    try:
        # Resize the input image
        temp_file = resize_image(input_image)
        
        # Preprocess the image
        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
        image = transform(image)
        image = image.to(device)

        # Run inference
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels
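        # If f_px is None (no EXIF focal length was available), Depth Pro
        # estimates the focal length itself and reports it as focallength_px.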

        # Convert depth to a numpy array if it's a torch tensor
        if isinstance(depth, torch.Tensor):
            depth = depth.detach().cpu().numpy()

        # Ensure depth is a 2D numpy array
        if depth.ndim != 2:
            depth = depth.squeeze()

        # Normalize depth to [0, 1] for visualization (guard against a flat map)
        depth_min = np.min(depth)
        depth_max = np.max(depth)
        depth_range = depth_max - depth_min
        depth_normalized = (depth - depth_min) / depth_range if depth_range > 0 else np.zeros_like(depth)
        
        # Create a color map
        plt.figure(figsize=(10, 10))
        plt.imshow(depth_normalized, cmap='viridis')
        plt.colorbar(label='Depth')
        plt.title('Predicted Depth Map')
        plt.axis('off')
        
        # Save the plot to a file
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()

        # focallength_px may be a 0-d tensor; cast to float before formatting
        return output_path, f"Focal length: {float(focallength_px):.2f} pixels"
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
    finally:
        # Clean up the temporary file
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)

# Create Gradio interface
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[gr.Image(type="filepath", label="Depth Map"), gr.Textbox(label="Focal Length or Error Message")],
    title="Depth Prediction Demo",
    description="Upload an image to predict its depth map and focal length. Large images will be automatically resized."
)

# Launch the interface
iface.launch()
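
# To run locally: `python app.py`, then open the URL Gradio prints.
# Passing share=True to launch() would additionally create a temporary public link.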