| import torch |
| import gradio as gr |
| import numpy as np |
| from PIL import Image |
| import torchvision.transforms as T |
| from model import DepthSTAR |
| import matplotlib.cm as cm |
|
|
| |
# Run on GPU when available; model weights and input tensors must share this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DepthSTAR().to(device)
# map_location lets a GPU-trained checkpoint load on CPU-only machines.
# NOTE(review): torch.load uses pickle under the hood; consider
# weights_only=True (PyTorch >= 1.13) if the checkpoint is not fully trusted.
model.load_state_dict(torch.load("depth_model_all.pth", map_location=device))
model.eval()  # inference mode: freeze dropout / batch-norm statistics
|
|
| |
# Preprocessing for every uploaded image: resize to 32x32 (presumably the
# model's native input resolution — confirm against DepthSTAR) and convert
# to a CHW float tensor scaled to [0, 1].
transform = T.Compose([
    T.Resize((32, 32)),
    T.ToTensor(),
])
|
|
| |
| image_output_size = 256 |
|
|
def predict_depth(image):
    """Predict a depth map for a PIL image and return display-ready outputs.

    Args:
        image: PIL.Image uploaded via Gradio (any mode or size).

    Returns:
        list: [input image resized for display, colorized depth map],
        both PIL images of size (image_output_size, image_output_size).
    """
    # Force 3-channel RGB: Gradio can deliver RGBA (PNG with transparency)
    # or grayscale images, which would give the model the wrong channel count.
    img = transform(image.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img)[0, 0].cpu().numpy()

    # Min-max normalize to [0, 1]; the epsilon guards a constant-depth map
    # (max == min) against division by zero.
    pred_normalized = (pred - pred.min()) / (pred.max() - pred.min() + 1e-8)

    # cm.get_cmap() was deprecated in matplotlib 3.7 and removed in 3.9;
    # use the colormaps registry and fall back for older installs.
    try:
        from matplotlib import colormaps
        cmap = colormaps["inferno"]
    except ImportError:
        cmap = cm.get_cmap('inferno')

    # Apply the colormap, drop the alpha channel, convert to 8-bit RGB.
    pred_colored = (cmap(pred_normalized)[:, :, :3] * 255).astype(np.uint8)
    pred_pil = Image.fromarray(pred_colored)

    # Upscale with NEAREST so the low-res prediction keeps crisp "pixels"
    # instead of being blurred; reuse the shared display size instead of a
    # second hard-coded 256.
    upscale_size = (image_output_size, image_output_size)
    image_resized = image.resize(upscale_size, resample=Image.NEAREST)
    pred_resized = pred_pil.resize(upscale_size, resample=Image.NEAREST)

    return [image_resized, pred_resized]
|
|
|
|
| |
| examples = [["img_000.png"],["img_001.png"],["img_002.png"],["img_003.png"],["img_004.png"],["img_005.png"],] |
| |
|
|
# Wire the model into a simple Gradio UI: one image in, two images out
# (the resized original side by side with the colorized depth prediction).
demo = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="pil", label="Input RGB Image", height=image_output_size),
    outputs=[
        gr.Image(type="pil", label="Original Image", height=image_output_size),
        gr.Image(type="pil", label="Predicted Depth Map", height=image_output_size),
    ],
    title="🔭 DepthStar: Light-weight Depth Estimation",
    description="Upload an RGB image and get the depth map predicted by our tiny DepthStar model.",
    examples=examples,
    # NOTE(review): "darkdefault" is a legacy (Gradio 2.x) theme name; recent
    # Gradio versions expect a gr.themes.* object or names like "default"/"soft"
    # and may warn and fall back to the default theme — confirm against the
    # installed Gradio version.
    theme="darkdefault",
)
|
|
# Start the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()