File size: 1,534 Bytes
b19928f
 
a394b1d
b19928f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import src.depth_pro as depth_pro

# Load model and preprocessing transform.
# This runs once at import time; predict_depth() below reads the resulting
# `model` and `transform` module globals on every call instead of rebuilding them.
model, transform = depth_pro.create_model_and_transforms()
# The model is only used for inference in this script, so put it in eval mode
# (presumably a torch.nn.Module -- confirm against depth_pro).
model.eval()

def predict_depth(input_image):
    """Predict a depth map and focal length for an uploaded image.

    Args:
        input_image: Location of the image on disk. The Gradio input that
            feeds this function is declared with ``type="filepath"``, so the
            value is normally a ``str`` path; an object exposing the path as
            a ``.name`` attribute is accepted as well.

    Returns:
        A ``(depth_map_path, focal_length_text)`` tuple: the path of a PNG
        file containing the colour-mapped, min-max normalized depth map, and
        a human-readable string with the focal length in pixels.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        # Submitting an empty form passes None; report it in the UI instead
        # of failing later with an opaque exception.
        raise gr.Error("Please upload an image first.")

    # A "filepath" input arrives as a plain string, which has no `.name`
    # attribute; fall back to the value itself in that case.
    image_path = getattr(input_image, "name", input_image)

    # Preprocess the image
    result = depth_pro.load_rgb(image_path)
    image = result[0]
    f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
    image = transform(image)

    # Run inference
    prediction = model.infer(image, f_px=f_px)
    depth = prediction["depth"]  # Depth in [m]
    focallength_px = prediction["focallength_px"]  # Focal length in pixels

    # NumPy reductions and matplotlib need a host-side array. Anything that
    # looks like a tensor (has `.detach`) is detached and moved to the CPU.
    if hasattr(depth, "detach"):
        depth = depth.detach().cpu().numpy()
    depth = np.asarray(depth, dtype=np.float32)

    # Normalize depth for visualization. A constant depth map has zero range;
    # map it to all zeros rather than dividing by zero.
    depth_min = depth.min()
    depth_range = depth.max() - depth_min
    if depth_range > 0:
        depth_normalized = (depth - depth_min) / depth_range
    else:
        depth_normalized = np.zeros_like(depth)

    # Draw on an explicit figure (not pyplot's implicit "current" figure) and
    # always close it, so a failure while plotting or saving cannot leak it.
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        mappable = ax.imshow(depth_normalized, cmap='viridis')
        fig.colorbar(mappable, ax=ax, label='Depth')
        ax.set_title('Predicted Depth Map')
        ax.axis('off')

        # Each call writes to its own temporary file so that simultaneous
        # requests cannot overwrite one another's output.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            output_path = tmp.name
        fig.savefig(output_path)
    finally:
        plt.close(fig)

    # float() accepts Python floats, NumPy scalars and single-element tensors,
    # so the format spec below works whichever type the model returned.
    return output_path, f"Focal length: {float(focallength_px):.2f} pixels"

# Assemble the Gradio UI: one image upload feeding predict_depth, whose two
# return values are shown as an image and a text box respectively.
image_input = gr.Image(type="filepath")
depth_map_output = gr.Image(type="filepath", label="Depth Map")
focal_length_output = gr.Textbox(label="Focal Length")

iface = gr.Interface(
    title="Depth Prediction Demo",
    description="Upload an image to predict its depth map and focal length.",
    fn=predict_depth,
    inputs=image_input,
    outputs=[depth_map_output, focal_length_output],
)

# Start serving the interface.
iface.launch()