# Gradio demo app: monocular depth estimation (indoor/outdoor) with three model backends.
import gradio as gr
import torch
from models.pretrained_decv2 import enc_dec_model
from models.densenet_v2 import Densenet
from models.unet_resnet18 import ResNet18UNet
from models.unet_resnet50 import UNetWithResnet50Encoder
import numpy as np
import cv2
# kb cropping
def cropping(img, target_h=352, target_w=1216):
    """Bottom-anchored, horizontally centered crop ("kb"/KITTI benchmark crop).

    Rows are taken from the bottom of the frame (the top margin — typically
    sky — is discarded) and columns are centered.  Defaults reproduce the
    standard KITTI evaluation crop of 352x1216.

    Args:
        img: numpy array of shape (H, W, ...) with H >= target_h and
            W >= target_w.
        target_h: output height in pixels (default 352).
        target_w: output width in pixels (default 1216).

    Returns:
        A view of ``img`` with shape (target_h, target_w, ...).
    """
    h_im, w_im = img.shape[:2]
    margin_top = h_im - target_h          # keep the bottom target_h rows
    margin_left = (w_im - target_w) // 2  # center horizontally
    return img[margin_top: margin_top + target_h,
               margin_left: margin_left + target_w]
def load_model(ckpt, model, optimizer=None):
    """Load model weights (and optionally optimizer state) from a checkpoint.

    Two checkpoint layouts are supported for backward compatibility:
      * a dict with a 'model' entry (and possibly an 'optimizer' entry), or
      * a bare state_dict.
    Keys prefixed with 'module.' (left over from nn.DataParallel training)
    are stripped before loading.

    Args:
        ckpt: path to a file readable by ``torch.load``.
        model: ``nn.Module`` whose parameters are loaded in place.
        optimizer: optional optimizer whose state is restored; requires the
            checkpoint to contain an 'optimizer' entry.
    """
    ckpt_dict = torch.load(ckpt, map_location='cpu')
    # Backward compatibility: a checkpoint without a 'model' entry is treated
    # as a bare state_dict.  (The old check required *both* 'model' and
    # 'optimizer' to be absent, so a checkpoint carrying only an 'optimizer'
    # entry crashed with KeyError.)
    state_dict = ckpt_dict['model'] if 'model' in ckpt_dict else ckpt_dict
    # Strip the 'module.' prefix added by nn.DataParallel wrappers.
    weights = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }
    model.load_state_dict(weights)
    if optimizer is not None:
        optimizer.load_state_dict(ckpt_dict['optimizer'])
# Inference device: use the GPU when one is available, else fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

CWD = "."

# Checkpoint filenames, keyed by scene type then architecture.
CKPT_FILE_NAMES = {
    'Indoor': {
        'Resnet_enc': 'resnet_nyu_best.ckpt',
        'Unet': 'resnet18_unet_epoch_08_model_kitti_and_nyu.ckpt',
        'Densenet_enc': 'densenet_epoch_15_model.ckpt',
    },
    'Outdoor': {
        'Resnet_enc': 'resnet_encdecmodel_epoch_05_model_nyu_and_kitti.ckpt',
        'Unet': 'resnet50_unet_epoch_02_model_nyuandkitti.ckpt',
        'Densenet_enc': 'densenet_nyu_then_kitti_epoch_10_model.ckpt',
    },
}

# Model instances per scene type.  Indoor models cap predicted depth at 10
# (NYU-style scenes), outdoor models at 80 (KITTI-style scenes).
MODEL_CLASSES = {
    'Indoor': {
        'Resnet_enc': enc_dec_model(max_depth=10),
        'Unet': ResNet18UNet(max_depth=10),
        'Densenet_enc': Densenet(max_depth=10),
    },
    'Outdoor': {
        'Resnet_enc': enc_dec_model(max_depth=80),
        'Unet': UNetWithResnet50Encoder(max_depth=80),
        'Densenet_enc': Densenet(max_depth=80),
    },
}

location_types = ['Indoor', 'Outdoor']
Models = ['Resnet_enc', 'Unet', 'Densenet_enc']

# Load every checkpoint once at startup so predict() never touches disk.
for scene in location_types:
    for arch in Models:
        ckpt_path = f"{CWD}/ckpt/{CKPT_FILE_NAMES[scene][arch]}"
        load_model(ckpt_path, MODEL_CLASSES[scene][arch])
def predict(location, model_name, img):
    """Run depth estimation on one image and return a color-mapped heatmap.

    Args:
        location: 'Indoor' or 'Outdoor' — selects the pre-loaded model pool.
        model_name: 'Resnet_enc', 'Unet' or 'Densenet_enc'.
        img: RGB image as an HxWx3 numpy array (Gradio input).

    Returns:
        HxWx3 uint8 RGB image: the predicted depth rendered with a rainbow
        colormap (near/far mapped through COLORMAP_RAINBOW).
    """
    model = MODEL_CLASSES[location][model_name].to(DEVICE)
    # Inference mode: freeze batchnorm running stats / disable dropout.
    model.eval()
    # Full-resolution KITTI frames get the standard benchmark crop.
    if img.shape == (375, 1242, 3):
        img = cropping(img)
    # HWC uint8 -> 1xCxHxW float batch.
    img = torch.tensor(img).permute(2, 0, 1).float().to(DEVICE)
    input_RGB = img.unsqueeze(0)
    with torch.no_grad():
        pred = model(input_RGB)
    pred_d = pred['pred_d']
    pred_d_numpy = pred_d.squeeze().cpu().numpy()
    # Normalize by the max depth below row 15 (skips top-of-frame artifacts;
    # assumes the prediction has more than 15 rows — TODO confirm for very
    # small inputs).  Guard against a degenerate all-zero prediction.
    scale = max(float(pred_d_numpy[15:, :].max()), 1e-8)
    pred_d_numpy = np.clip((pred_d_numpy / scale) * 255, 0, 255)
    pred_d_numpy = pred_d_numpy.astype(np.uint8)
    # OpenCV colormaps emit BGR; convert to RGB for Gradio display.
    pred_d_color = cv2.applyColorMap(pred_d_numpy, cv2.COLORMAP_RAINBOW)
    return cv2.cvtColor(pred_d_color, cv2.COLOR_BGR2RGB)
# Gradio front-end: scene/model selectors, an input/output image pair, a
# trigger button wired to predict(), and clickable sample images.
with gr.Blocks() as demo:
    gr.Markdown("# Monocular Depth Estimation")
    # Row 1: scene type and model architecture selectors.
    with gr.Row():
        location = gr.Radio(choices=['Indoor', 'Outdoor'],value='Indoor', label = "Select Location Type")
        model_name = gr.Radio(['Unet', 'Resnet_enc', 'Densenet_enc'],value="Densenet_enc" ,label="Select model")
    # Row 2: input image beside the predicted depth heatmap.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label = "Input Image for Depth Estimation")
        with gr.Column():
            output_depth_map = gr.Image(label = "Depth prediction Heatmap")
    # Row 3: the button that runs inference.
    with gr.Row():
        predict_btn = gr.Button("Generate Depthmap")
    predict_btn.click(fn=predict, inputs=[location, model_name, input_image], outputs=output_depth_map)
    # Sample inputs (indoor scenes plus one KITTI frame — presumably shipped
    # in ./demo_data; verify the files exist in deployment).
    with gr.Row():
        gr.Examples(['./demo_data/Bathroom.jpg', './demo_data/Bedroom.jpg', './demo_data/Bookstore.jpg', './demo_data/Classroom.jpg', './demo_data/Computerlab.jpg', './demo_data/kitti_1.png'], inputs=input_image)
demo.launch()