Spaces: Running on Zero
import gradio as gr | |
from PIL import Image | |
import src.depth_pro as depth_pro | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import subprocess | |
import spaces | |
import torch | |
import tempfile | |
import os | |
# Fetch the pretrained checkpoints by running the bundled helper script.
# NOTE(review): the exit status is not checked, so a failed download only
# surfaces later, when the model is created below.
subprocess.run(["bash", "get_pretrained_models.sh"])

# Use the first CUDA device when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the Depth Pro model together with its matching input transform,
# place the model on the selected device and switch it to evaluation mode.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()
def resize_image(image_path, max_size=1024):
    """Shrink an image so its longest side is at most ``max_size`` pixels.

    The result is written to a new temporary ``.png`` file (created with
    ``delete=False``), and that file's path is returned. The caller owns the
    file and must delete it. Images that already fit within ``max_size`` are
    re-saved unchanged rather than upscaled, so only *large* images are
    resized.

    Args:
        image_path: Path of the source image, as accepted by ``Image.open``.
        max_size: Maximum allowed length, in pixels, of the longest side.

    Returns:
        str: Path of the temporary PNG holding the (possibly resized) image.

    Raises:
        Any error from opening, resizing or saving the image. If saving
        fails, the partially written temporary file is removed first.
    """
    with Image.open(image_path) as img:
        longest_side = max(img.size)
        if longest_side > max_size:
            ratio = max_size / longest_side
            # round() keeps the long side at exactly max_size despite float
            # error, and the 1 px floor stops very thin images from
            # collapsing to a zero-sized dimension, which PIL rejects.
            new_size = tuple(max(1, round(side * ratio)) for side in img.size)
            img = img.resize(new_size, Image.LANCZOS)

        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
        try:
            with temp_file:
                img.save(temp_file, format="PNG")
        except Exception:
            # The caller never receives the path on failure, so clean up here
            # to avoid leaking the temp file.
            os.remove(temp_file.name)
            raise

    return temp_file.name
def _infer_depth(image_path):
    """Run the Depth Pro model on the image file at *image_path*.

    Returns:
        tuple: ``(depth, focallength_px)`` as taken from the model's
        prediction dict.
    """
    result = depth_pro.load_rgb(image_path)
    image = result[0]
    # NOTE(review): assumes the focal length in pixels (or None) is the last
    # element of the tuple returned by depth_pro.load_rgb — confirm.
    f_px = result[-1]
    image = transform(image).to(device)
    prediction = model.infer(image, f_px=f_px)
    # Depth in [m]; focal length in pixels.
    return prediction["depth"], prediction["focallength_px"]


def _render_depth_map(depth):
    """Render *depth* as a colour-mapped PNG in a new temp file and return its path."""
    from matplotlib.figure import Figure

    if isinstance(depth, torch.Tensor):
        depth = depth.detach().cpu().numpy()
    depth = np.asarray(depth)
    # Drop singleton axes so imshow receives a 2-D array.
    if depth.ndim != 2:
        depth = depth.squeeze()

    # Normalize to [0, 1] for visualization. A constant map would otherwise
    # divide by zero and render as all-NaN.
    depth_min = np.min(depth)
    depth_range = np.max(depth) - depth_min
    if depth_range > 0:
        depth_normalized = (depth - depth_min) / depth_range
    else:
        depth_normalized = np.zeros_like(depth, dtype=float)

    # Use the object-oriented Figure API instead of pyplot. Pyplot's global
    # figure state is shared by all threads serving requests, and a Figure
    # built this way is never registered with pyplot, so nothing is left
    # open if saving fails.
    fig = Figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    mappable = ax.imshow(depth_normalized, cmap='viridis')
    fig.colorbar(mappable, ax=ax, label='Depth')
    ax.set_title('Predicted Depth Map')
    ax.axis('off')

    # Write to a unique file per request. A fixed "depth_map.png" in the
    # working directory is shared by every request, so concurrent users
    # could receive each other's results.
    # NOTE(review): nothing here deletes these files after they are served;
    # consider a cleanup policy.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as out_file:
        output_path = out_file.name
    try:
        fig.savefig(output_path)
    except Exception:
        os.remove(output_path)
        raise
    return output_path


@spaces.GPU
def predict_depth(input_image):
    """Predict a depth map for an uploaded image and render it for display.

    The input is first downscaled via ``resize_image``, then run through
    Depth Pro. The resulting depth map is saved as a colour-mapped PNG.
    Every error is caught and reported as a message rather than raised, so
    the UI always gets a response.

    Decorated with ``spaces.GPU`` (``spaces`` is imported by this file but
    was otherwise unused) so that a ZeroGPU Space can attach a GPU for the
    call. NOTE(review): the decorator is documented as a no-op outside
    ZeroGPU — confirm for the deployed ``spaces`` version.

    Args:
        input_image: Path of the uploaded image (the Gradio input uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        tuple: ``(depth_map_path, message)``. On success the path points to
        the rendered PNG and the message reports the focal length. On
        failure the path is ``None`` and the message describes the problem.
    """
    if input_image is None:
        return None, "Please upload an image."

    temp_file = None
    try:
        temp_file = resize_image(input_image)
        depth, focallength_px = _infer_depth(temp_file)
        output_path = _render_depth_map(depth)
        # float() handles both plain numbers and single-element tensors,
        # whereas formatting a tensor with ":.2f" only works when it is 0-d.
        return output_path, f"Focal length: {float(focallength_px):.2f} pixels"
    except Exception as e:
        return None, f"An error occurred: {str(e)}"
    finally:
        # Always remove the resized intermediate image.
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)
# Assemble the Gradio UI: one uploaded image goes in; a rendered depth-map
# image plus a focal-length (or error) message come out.
image_input = gr.Image(type="filepath")
depth_output = gr.Image(type="filepath", label="Depth Map")
message_output = gr.Textbox(label="Focal Length or Error Message")

iface = gr.Interface(
    fn=predict_depth,
    inputs=image_input,
    outputs=[depth_output, message_output],
    title="Depth Prediction Demo",
    description=(
        "Upload an image to predict its depth map and focal length. "
        "Large images will be automatically resized."
    ),
)

# Start serving the app.
iface.launch()