# DepthPoseEstimation / depth_app.py
# Source commit e86792f ("first code", by mkalia; 3.66 kB).
import streamlit as st
import numpy as np
import PIL.Image as pil
import torch
import os
import io
from torchvision import transforms
import matplotlib as mpl
import matplotlib.cm as cm
import networks
from layers import disp_to_depth_no_scaling
# Function to load the model
def load_model(device, model_path):
    """Build the encoder/decoder networks and load their weights from ``model_path``.

    Reads ``encoder.pth`` and ``depth.pth`` from ``model_path``. Both networks are
    moved to ``device`` and put in eval mode.

    Args:
        device: torch device to map the checkpoints to and to run the networks on.
        model_path: directory containing ``encoder.pth`` and ``depth.pth``.

    Returns:
        Tuple ``(encoder, depth_decoder, feed_height, feed_width)``. ``feed_height``
        and ``feed_width`` are read from the encoder checkpoint's ``'height'`` and
        ``'width'`` entries; ``predict_depth`` resizes inputs to that size.
    """
    encoder_path = os.path.join(model_path, "encoder.pth")
    depth_decoder_path = os.path.join(model_path, "depth.pth")

    encoder = networks.ResnetEncoder(18, False)
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints
    # (main fetches these from a fixed GitHub release URL).
    loaded_dict_enc = torch.load(encoder_path, map_location=device)
    feed_height = loaded_dict_enc['height']
    feed_width = loaded_dict_enc['width']
    # Build the state dict once (not once per checkpoint key) and keep only entries the
    # encoder actually has, which drops non-parameter keys such as 'height'/'width'.
    encoder_state = encoder.state_dict()
    filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in encoder_state}
    encoder.load_state_dict(filtered_dict_enc)
    encoder.to(device)
    encoder.eval()

    depth_decoder = networks.DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
    loaded_dict = torch.load(depth_decoder_path, map_location=device)
    # strict=False: keys missing from or extra to the checkpoint are ignored silently.
    depth_decoder.load_state_dict(loaded_dict, strict=False)
    depth_decoder.to(device)
    depth_decoder.eval()
    return encoder, depth_decoder, feed_height, feed_width
# Function to apply the model to an image
def predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width):
    """Run the depth networks on ``image`` and return a colour-mapped disparity image.

    Args:
        image: PIL image to process (``main`` passes an RGB-converted upload).
        encoder: encoder network returned by ``load_model``.
        depth_decoder: decoder network returned by ``load_model``.
        device: torch device the networks live on.
        feed_height: network input height; the image is resized to it.
        feed_width: network input width; the image is resized to it.

    Returns:
        ``uint8`` RGB array (H, W, 3) rendered with the ``magma`` colormap, at the
        resolution of the decoder's ``("disp", 0)`` output (it is NOT resized back to
        the original image size -- presumably feed_height x feed_width; confirm).
    """
    input_image = image.resize((feed_width, feed_height), pil.LANCZOS)
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

    # Inference only: without no_grad the outputs require grad and the .numpy() call
    # below raises "Can't call numpy() on Tensor that requires grad".
    with torch.no_grad():
        features = encoder(input_tensor)
        outputs = depth_decoder(features)
    disp = outputs[("disp", 0)]

    disp_np = disp.squeeze().cpu().numpy()
    # Cap the colour scale at the 95th percentile so a few very large disparities do
    # not compress the rest of the image into the dark end of the colormap.
    vmax = np.percentile(disp_np, 95)
    normalizer = mpl.colors.Normalize(vmin=disp_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    colormapped_im = (mapper.to_rgba(disp_np)[:, :, :3] * 255).astype(np.uint8)
    return colormapped_im
# Streamlit app
_WEIGHTS_URL = "https://github.com/meghakalia/depthEstimationColonoscopy/releases/download/0.0.1/{}"
_WEIGHT_FILES = ("depth.pth", "encoder.pth")


def _ensure_weights(model_path):
    """Download every file in ``_WEIGHT_FILES`` that is missing from ``model_path``.

    Each file is streamed to a ``.part`` sibling and renamed into place only once the
    download has finished. An interrupted download therefore never leaves a truncated
    checkpoint that the existence check would later accept.

    Raises:
        OSError: if a download or rename fails (urllib.error.URLError is a subclass).
    """
    import shutil
    import urllib.request

    for name in _WEIGHT_FILES:
        dest = os.path.join(model_path, name)
        if os.path.exists(dest):
            continue
        tmp = dest + ".part"
        try:
            with urllib.request.urlopen(_WEIGHTS_URL.format(name), timeout=60) as response, \
                    open(tmp, "wb") as out:
                shutil.copyfileobj(response, out)
            os.replace(tmp, dest)
        finally:
            # Only present if the download or rename failed part-way.
            if os.path.exists(tmp):
                os.remove(tmp)


def main():
    """Build the Streamlit page.

    Fetches the model weights if they are missing, lets the user upload a PNG, and
    on "Predict Depth" shows the colour-mapped depth image with a download button.
    """
    st.title("Monodepth2 Depth Estimation")
    st.write("Upload a PNG image to get depth estimation")

    model_path = '.'  # Update this to the path of your model
    try:
        _ensure_weights(model_path)
    except OSError as err:
        st.error(f"Could not download the model weights: {err}")
        return

    uploaded_file = st.file_uploader("Choose a PNG file", type="png")
    if uploaded_file is None:
        return
    image = pil.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if not st.button('Predict Depth'):
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder, depth_decoder, feed_height, feed_width = load_model(device, model_path)
    colormapped_im = predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width)
    depth_image = pil.fromarray(colormapped_im)
    st.image(depth_image, caption='Predicted Depth Image', use_column_width=True)

    # Encode the depth image as PNG bytes for the download button.
    buffer = io.BytesIO()
    depth_image.save(buffer, format='PNG')
    st.download_button(
        label="Download Depth Image",
        data=buffer.getvalue(),
        file_name="depth_image.png",
        mime="image/png",
    )


if __name__ == "__main__":
    main()