Spaces:
Sleeping
Sleeping
File size: 4,323 Bytes
e86792f d4950f4 73180f7 33485a1 73180f7 e86792f d4950f4 e86792f 73180f7 e86792f 73180f7 e86792f 33485a1 e86792f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import streamlit as st
import numpy as np
import PIL.Image as pil
import torch
import os
import io
from torchvision import transforms
import matplotlib as mpl
import matplotlib.cm as cm
from resnet_encoder import ResnetEncoder
from depth_decoder import DepthDecoder
import os
import subprocess
import requests
def download_model_files(dest_dir=".", timeout=60, chunk_size=1 << 20):
    """Download the pretrained checkpoints ("depth.pth", "encoder.pth") if missing.

    Each file is fetched from the project's GitHub release and written into
    ``dest_dir`` under its original name. Files that already exist are left
    untouched, so calling this on every app rerun is cheap.

    The body is streamed to a temporary ``<name>.part`` file. That file is
    renamed into place only after the download finishes, so a failed or
    interrupted download never leaves a truncated checkpoint behind. Such a
    file would otherwise be mistaken for a complete one on the next run.

    Args:
        dest_dir: Directory the checkpoints are written to. The default ".",
            the current working directory, matches the previous behaviour.
        timeout: Per-request connect/read timeout in seconds passed to
            ``requests.get``. Without it a stalled server hangs the app.
        chunk_size: Number of bytes read per streamed chunk.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (e.g. 404). Previously the error page body was saved as the
            checkpoint and never re-downloaded.
        requests.RequestException: On connection errors or timeouts.
    """
    checkpoints = (
        ("depth.pth", "https://github.com/meghakalia/depthEstimationColonoscopy/releases/download/0.0.1/depth.pth"),
        ("encoder.pth", "https://github.com/meghakalia/depthEstimationColonoscopy/releases/download/0.0.1/encoder.pth"),
    )
    for filename, url in checkpoints:
        target = os.path.join(dest_dir, filename)
        if os.path.exists(target):
            continue
        partial = target + ".part"
        try:
            with requests.get(url, stream=True, timeout=timeout) as response:
                # Fail loudly on 4xx/5xx instead of saving an error page as weights.
                response.raise_for_status()
                with open(partial, "wb") as f:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
            # Atomic rename: `target` only ever appears fully written.
            os.replace(partial, target)
        finally:
            # On any failure, remove the half-written temp file. After a
            # successful replace it no longer exists, so this is a no-op.
            if os.path.exists(partial):
                os.remove(partial)
# from pose_decoder import PoseDecoder
# from pose_cnn import PoseCNN
# def disp_to_depth_no_scaling(disp):
# """Convert network's sigmoid output into depth prediction
# The formula for this conversion is given in the 'additional considerations'
# section of the paper.
# """
# depth = 1 / (disp + 1e-7)
# return depth
def load_model(device, model_path):
    """Build the encoder/decoder pair and restore their pretrained weights.

    Expects two checkpoint files inside ``model_path``: ``encoder.pth`` and
    ``depth.pth``.

    Args:
        device: Target device. It is used as ``map_location`` when reading
            the checkpoints, and both modules are moved onto it.
        model_path: Directory containing the two checkpoint files.

    Returns:
        A tuple ``(encoder, depth_decoder, feed_height, feed_width)``.
        Both modules are switched to eval mode. The two ints are the
        ``'height'`` / ``'width'`` values stored in the encoder checkpoint.
    """
    def read_checkpoint(filename):
        # Deserialize one checkpoint from model_path onto the target device.
        return torch.load(os.path.join(model_path, filename), map_location=device)

    encoder = ResnetEncoder(18, False)
    enc_state = read_checkpoint("encoder.pth")
    # Besides weights, the encoder checkpoint carries the network input size.
    feed_height = enc_state['height']
    feed_width = enc_state['width']
    # Keep only entries that exist in the encoder's own state dict. This
    # drops extra metadata (presumably the 'height'/'width' entries above).
    own_keys = set(encoder.state_dict())
    encoder.load_state_dict({key: value for key, value in enc_state.items() if key in own_keys})
    encoder.to(device)
    encoder.eval()

    depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
    # strict=False: checkpoint keys that don't match the decoder are ignored
    # rather than raising.
    depth_decoder.load_state_dict(read_checkpoint("depth.pth"), strict=False)
    depth_decoder.to(device)
    depth_decoder.eval()

    return encoder, depth_decoder, feed_height, feed_width
def _colorize_disparity(disp_np, cmap='magma', vmax_percentile=95):
    """Map a 2-D disparity array to an RGB ``uint8`` image using a matplotlib colormap."""
    # Cap the colour scale at a high percentile so a few very large
    # disparities don't squeeze the rest of the range into one colour.
    vmax = np.percentile(disp_np, vmax_percentile)
    normalizer = mpl.colors.Normalize(vmin=disp_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap=cmap)
    # to_rgba yields RGBA floats in [0, 1]; drop alpha and scale to 0..255.
    return (mapper.to_rgba(disp_np)[:, :, :3] * 255).astype(np.uint8)


def predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width,
                  cmap='magma', vmax_percentile=95):
    """Run the depth network on a PIL image and return a colour-mapped disparity map.

    Args:
        image: PIL image (``main`` passes an RGB-converted upload).
        encoder: Encoder module. Its output is fed straight to ``depth_decoder``.
        depth_decoder: Decoder module. Its output dict must contain the key
            ``("disp", 0)``.
        device: Device the models live on. The input tensor is moved there.
        feed_height, feed_width: Size the image is resized to (Lanczos)
            before inference.
        cmap: Matplotlib colormap name used for the visualisation.
        vmax_percentile: Percentile of the disparity used as the top of the
            colour scale.

    Returns:
        A ``uint8`` array of shape ``(H, W, 3)`` holding RGB values. H and W
        are the spatial size of the decoder's ``("disp", 0)`` output. This is
        presumably feed_height x feed_width; the map is NOT resized back to
        the original image size.
    """
    resized = image.resize((feed_width, feed_height), pil.LANCZOS)
    batch = transforms.ToTensor()(resized).unsqueeze(0).to(device)

    # Inference only. Without no_grad the output tracks gradients, because
    # eval() changes layer behaviour but does not disable autograd. Calling
    # .numpy() below would then raise "Can't call numpy() on Tensor that
    # requires grad", and the autograd graph would waste memory.
    with torch.no_grad():
        features = encoder(batch)
        outputs = depth_decoder(features)

    disp_np = outputs[("disp", 0)].squeeze().cpu().numpy()
    return _colorize_disparity(disp_np, cmap=cmap, vmax_percentile=vmax_percentile)
def _png_bytes(image):
    """Encode a PIL image as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return buffer.getvalue()


def main():
    """Streamlit page: upload a PNG, predict its depth map, and offer it for download."""
    st.title("Monodepth2 Depth Estimation")
    st.write("Upload a PNG image to get depth estimation")
    # Fetches the checkpoints into the working directory; it skips files
    # that already exist, so running it on every rerun is cheap.
    download_model_files()

    uploaded_file = st.file_uploader("Choose a PNG file", type="png")
    if uploaded_file is None:
        return

    try:
        image = pil.open(uploaded_file).convert('RGB')
    except OSError:
        # `type="png"` only restricts the file extension. A corrupt or
        # mislabelled file would otherwise show up as a raw traceback.
        # (PIL's UnidentifiedImageError is a subclass of OSError.)
        st.error("Could not read the uploaded file as an image.")
        return

    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favour of use_container_width; kept as-is since the installed version
    # is unknown.
    st.image(image, caption='Uploaded Image', use_column_width=True)
    if not st.button('Predict Depth'):
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = '.'  # download_model_files() saves the checkpoints here
    encoder, depth_decoder, feed_height, feed_width = load_model(device, model_path)
    colormapped_im = predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width)
    depth_image = pil.fromarray(colormapped_im)
    st.image(depth_image, caption='Predicted Depth Image', use_column_width=True)

    st.download_button(
        label="Download Depth Image",
        data=_png_bytes(depth_image),
        file_name="depth_image.png",
        mime="image/png",
    )
if __name__ == "__main__":
    # Script entry point: build and run the Streamlit page.
    main()
|