DepthPoseEstimation / depth_app.py
mkalia's picture
Update depth_app.py
22f2ad4 verified
raw
history blame
4.03 kB
import streamlit as st
import subprocess
import numpy as np
import PIL.Image as pil
import torch
import os
import io
from torchvision import transforms
import matplotlib as mpl
import matplotlib.cm as cm
from resnet_encoder import ResnetEncoder
from depth_decoder import DepthDecoder
# from pose_decoder import PoseDecoder
# from pose_cnn import PoseCNN
# def disp_to_depth_no_scaling(disp):
# """Convert network's sigmoid output into depth prediction
# The formula for this conversion is given in the 'additional considerations'
# section of the paper.
# """
# depth = 1 / (disp + 1e-7)
# return depth
# Function to load the model
def load_model(device, model_path):
    """Load the encoder and depth decoder checkpoints found in *model_path*.

    Expects ``encoder.pth`` and ``depth.pth`` inside *model_path*. Both
    checkpoints are mapped onto *device*, and both networks are moved to
    *device* and switched to eval mode before being returned.

    Args:
        device: Target ``torch.device`` (or anything ``map_location`` /
            ``Module.to`` accepts).
        model_path: Directory that holds the two checkpoint files.

    Returns:
        Tuple ``(encoder, depth_decoder, feed_height, feed_width)`` where the
        two sizes are the ``'height'`` and ``'width'`` entries stored in the
        encoder checkpoint.
    """
    # --- Encoder -----------------------------------------------------------
    # NOTE(review): (18, False) presumably means ResNet-18 without pretrained
    # ImageNet weights — confirm against resnet_encoder.ResnetEncoder.
    encoder = ResnetEncoder(18, False)
    enc_checkpoint = torch.load(
        os.path.join(model_path, "encoder.pth"), map_location=device
    )
    feed_height = enc_checkpoint['height']
    feed_width = enc_checkpoint['width']

    # The checkpoint carries extra entries (e.g. 'height' / 'width') next to
    # the weights; keep only the keys the encoder actually owns so that the
    # strict load below succeeds.
    known_keys = encoder.state_dict()
    enc_weights = {}
    for key, value in enc_checkpoint.items():
        if key in known_keys:
            enc_weights[key] = value
    encoder.load_state_dict(enc_weights)
    encoder.to(device)
    encoder.eval()

    # --- Depth decoder -----------------------------------------------------
    depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
    dec_checkpoint = torch.load(
        os.path.join(model_path, "depth.pth"), map_location=device
    )
    # strict=False: keys missing from / unexpected in the checkpoint are
    # tolerated rather than raising.
    depth_decoder.load_state_dict(dec_checkpoint, strict=False)
    depth_decoder.to(device)
    depth_decoder.eval()

    return encoder, depth_decoder, feed_height, feed_width
# Function to apply the model to an image
def predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width):
    """Run the depth network on *image* and return a colour-mapped result.

    The image is resized to the network's feed size, pushed through the
    encoder and depth decoder, and the ``("disp", 0)`` output is rendered
    with the ``magma`` colormap.

    Args:
        image: PIL image to process (the caller passes an RGB image).
        encoder: Encoder network, called on the input tensor.
        depth_decoder: Decoder network, called on the encoder features; must
            return a mapping containing the key ``("disp", 0)``.
        device: Device the input tensor is moved to before inference.
        feed_height: Height the image is resized to before inference.
        feed_width: Width the image is resized to before inference.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with three colour channels on
        the last axis. Its spatial size is that of the network output, not of
        the original *image* (the output is not resized back).
    """
    input_image = image.resize((feed_width, feed_height), pil.LANCZOS)
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

    # Inference only: without no_grad the output is attached to the autograd
    # graph, which wastes memory and makes the .numpy() call below raise
    # "Can't call numpy() on Tensor that requires grad".
    with torch.no_grad():
        features = encoder(input_tensor)
        outputs = depth_decoder(features)

    disp = outputs[("disp", 0)]
    # detach() is a no-op under no_grad but keeps the conversion safe even if
    # a decoder re-enables gradients internally.
    disp_np = disp.squeeze().detach().cpu().numpy()

    # Clip the colour scale at the 95th percentile so that a few very large
    # values do not wash out the rest of the map.
    vmax = np.percentile(disp_np, 95)
    normalizer = mpl.colors.Normalize(vmin=disp_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')

    # to_rgba yields floats in [0, 1] with an alpha channel; drop alpha and
    # convert to 8-bit RGB.
    colormapped_im = (mapper.to_rgba(disp_np)[:, :, :3] * 255).astype(np.uint8)
    return colormapped_im
# Streamlit app
# GitHub release that hosts the pretrained weights fetched on first launch.
_WEIGHTS_BASE_URL = "https://github.com/meghakalia/depthEstimationColonoscopy/releases/download/0.0.1/"
_WEIGHT_FILES = ("depth.pth", "encoder.pth")


def _download_if_missing(filename, url):
    """Fetch *url* into *filename* with ``wget`` unless the file exists.

    Returns:
        True when *filename* is present afterwards, False when the download
        failed (``wget`` missing or exiting with a non-zero status).
    """
    if os.path.exists(filename):
        return True
    try:
        # check=True turns a failed download into an exception instead of
        # letting it pass silently and surface later inside torch.load.
        subprocess.run(["wget", "-O", filename, url], check=True)
    except (OSError, subprocess.CalledProcessError):
        # Remove any partial file so the existence check above does not
        # mistake it for a complete download on the next run.
        if os.path.exists(filename):
            os.remove(filename)
        return False
    return True


def main():
    """Streamlit entry point: upload a PNG, predict depth, offer a download.

    On start the model weights are downloaded into the current directory if
    they are not there yet. After the user uploads a PNG and presses the
    button, the model is loaded, the colour-mapped prediction is displayed
    and a download button for the result is shown.
    """
    st.title("Monodepth2 Depth Estimation")
    st.write("Upload a PNG image to get depth estimation")

    for filename in _WEIGHT_FILES:
        if not _download_if_missing(filename, _WEIGHTS_BASE_URL + filename):
            st.error(f"Could not download model weights: {filename}")
            return

    uploaded_file = st.file_uploader("Choose a PNG file", type="png")
    if uploaded_file is None:
        return

    image = pil.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if not st.button('Predict Depth'):
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = '.'  # Directory the weights were downloaded into above.
    encoder, depth_decoder, feed_height, feed_width = load_model(device, model_path)

    colormapped_im = predict_depth(
        image, encoder, depth_decoder, device, feed_height, feed_width
    )
    depth_image = pil.fromarray(colormapped_im)
    st.image(depth_image, caption='Predicted Depth Image', use_column_width=True)

    # Serialise the result to PNG bytes in memory for the download button.
    buffer = io.BytesIO()
    depth_image.save(buffer, format='PNG')
    img_bytes = buffer.getvalue()

    st.download_button(
        label="Download Depth Image",
        data=img_bytes,
        file_name="depth_image.png",
        mime="image/png",
    )
# Start the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()