DepthPoseEstimation / depth_app.py
mkalia's picture
Update depth_app.py
33485a1 verified
raw
history blame
4.32 kB
import streamlit as st
import numpy as np
import PIL.Image as pil
import torch
import os
import io
from torchvision import transforms
import matplotlib as mpl
import matplotlib.cm as cm
from resnet_encoder import ResnetEncoder
from depth_decoder import DepthDecoder
import os
import subprocess
import requests
def _download_file(url, destination, timeout=60, chunk_size=1 << 20):
    """Stream *url* to *destination*, replacing it atomically on success.

    The payload is written to ``<destination>.part`` first and only renamed
    to *destination* once the whole body has been received, so an interrupted
    download never leaves a truncated file under the final name.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status (via ``raise_for_status``).
    """
    partial_path = destination + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail loudly instead of saving an HTML error page as a checkpoint.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(partial_path, destination)
    finally:
        # Only present here if the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)


def download_model_files():
    """Download ``depth.pth`` and ``encoder.pth`` into the working directory.

    A file that already exists is left untouched and is not downloaded again.

    Raises:
        requests.RequestException: if a missing file cannot be downloaded.
    """
    model_urls = {
        "depth.pth": "https://github.com/meghakalia/depthEstimationColonoscopy/releases/download/0.0.1/depth.pth",
        "encoder.pth": "https://github.com/meghakalia/depthEstimationColonoscopy/releases/download/0.0.1/encoder.pth",
    }
    for filename, url in model_urls.items():
        if not os.path.exists(filename):
            _download_file(url, filename)
# from pose_decoder import PoseDecoder
# from pose_cnn import PoseCNN
# def disp_to_depth_no_scaling(disp):
# """Convert network's sigmoid output into depth prediction
# The formula for this conversion is given in the 'additional considerations'
# section of the paper.
# """
# depth = 1 / (disp + 1e-7)
# return depth
# Function to load the model
def load_model(device, model_path):
    """Build the encoder and depth decoder and load their saved weights.

    Args:
        device: torch device the checkpoints are mapped to and the modules
            are moved to.
        model_path: directory containing ``encoder.pth`` and ``depth.pth``.

    Returns:
        A tuple ``(encoder, depth_decoder, feed_height, feed_width)``; the
        last two values are read from the ``'height'`` and ``'width'``
        entries of the encoder checkpoint. Both modules are in eval mode.
    """
    encoder_file, decoder_file = (
        os.path.join(model_path, name) for name in ("encoder.pth", "depth.pth")
    )

    # --- encoder ---
    encoder = ResnetEncoder(18, False)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    encoder_ckpt = torch.load(encoder_file, map_location=device)
    feed_height, feed_width = encoder_ckpt['height'], encoder_ckpt['width']
    # The checkpoint holds extra entries (e.g. 'height'/'width'); keep only
    # the keys that the encoder's own state dict knows about.
    encoder_weights = {}
    for key, value in encoder_ckpt.items():
        if key in encoder.state_dict():
            encoder_weights[key] = value
    encoder.load_state_dict(encoder_weights)
    encoder.to(device)
    encoder.eval()

    # --- depth decoder ---
    depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
    decoder_ckpt = torch.load(decoder_file, map_location=device)
    # strict=False: missing or unexpected keys do not raise here.
    depth_decoder.load_state_dict(decoder_ckpt, strict=False)
    depth_decoder.to(device)
    depth_decoder.eval()

    return encoder, depth_decoder, feed_height, feed_width
# Function to apply the model to an image
def predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width):
    """Run the depth network on *image* and return a colour-mapped result.

    Args:
        image: image object providing ``resize`` (a PIL image in ``main``).
        encoder: feature encoder called on the input tensor.
        depth_decoder: decoder called on the encoder output; its result is
            indexed with the key ``("disp", 0)``.
        device: torch device the input tensor is moved to.
        feed_height: height the image is resized to before inference.
        feed_width: width the image is resized to before inference.

    Returns:
        ``numpy.uint8`` array with three colour channels on the last axis,
        produced with the ``magma`` colormap. The output is not resized back
        to the original image size.
    """
    input_image = image.resize((feed_width, feed_height), pil.LANCZOS)
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

    # Inference only: disable autograd. Otherwise the output can carry
    # requires_grad=True, in which case .numpy() raises a RuntimeError, and a
    # needless computation graph is kept alive.
    with torch.no_grad():
        features = encoder(input_tensor)
        outputs = depth_decoder(features)

    disp_np = outputs[("disp", 0)].squeeze().cpu().numpy()

    # Clip the colour scale at the 95th percentile so a few very large
    # values do not wash out the rest of the map.
    vmax = np.percentile(disp_np, 95)
    normalizer = mpl.colors.Normalize(vmin=disp_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    # to_rgba yields RGBA floats; drop alpha and scale to 8-bit.
    colormapped_im = (mapper.to_rgba(disp_np)[:, :, :3] * 255).astype(np.uint8)
    return colormapped_im
# Streamlit app
def main():
    """Streamlit page: upload a PNG, predict its depth map, offer a download.

    The model files are fetched first if needed. Nothing beyond the uploader
    is rendered until a file is chosen, and the model is only loaded and run
    once the 'Predict Depth' button is pressed.
    """
    st.title("Monodepth2 Depth Estimation")
    st.write("Upload a PNG image to get depth estimation")
    download_model_files()

    upload = st.file_uploader("Choose a PNG file", type="png")
    if upload is None:
        return

    source_image = pil.open(upload).convert('RGB')
    st.image(source_image, caption='Uploaded Image', use_column_width=True)

    if not st.button('Predict Depth'):
        return

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Checkpoints live in the working directory; update if they are moved.
    encoder, depth_decoder, feed_height, feed_width = load_model(device, '.')
    depth_image = pil.fromarray(
        predict_depth(source_image, encoder, depth_decoder, device, feed_height, feed_width)
    )
    st.image(depth_image, caption='Predicted Depth Image', use_column_width=True)

    # Encode the result as PNG in memory so it can be offered for download.
    png_buffer = io.BytesIO()
    depth_image.save(png_buffer, format='PNG')
    st.download_button(
        label="Download Depth Image",
        data=png_buffer.getvalue(),
        file_name="depth_image.png",
        mime="image/png",
    )


# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()