# Spaces:
# Sleeping
# Sleeping
import streamlit as st | |
import numpy as np | |
import PIL.Image as pil | |
import torch | |
import os | |
import io | |
from torchvision import transforms | |
import matplotlib as mpl | |
import matplotlib.cm as cm | |
from resnet_encoder import ResnetEncoder | |
from depth_decoder import DepthDecoder | |
import os | |
import subprocess | |
import requests | |
def download_model_files():
    """Download the depth-decoder and encoder checkpoints if they are missing.

    Each checkpoint is fetched from the project's GitHub release and stored in
    the current working directory as ``depth.pth`` / ``encoder.pth``. A file
    that already exists is left untouched.

    Raises:
        requests.RequestException: if the request fails, times out, or the
            server answers with an HTTP error status.
        OSError: if the checkpoint cannot be written to disk.
    """
    depth_url = "https://github.com/meghakalia/depthEstimationColonoscopy/releases/download/0.0.1/depth.pth"
    encoder_url = "https://github.com/meghakalia/depthEstimationColonoscopy/releases/download/0.0.1/encoder.pth"

    for filename, url in (("depth.pth", depth_url), ("encoder.pth", encoder_url)):
        if os.path.exists(filename):
            continue

        # Download into a side file and rename only once it is complete, so an
        # interrupted transfer never leaves a truncated checkpoint that the
        # existence check above would accept on the next run.
        partial_path = filename + ".part"
        try:
            # (connect timeout, read timeout) in seconds: never hang forever.
            with requests.get(url, stream=True, timeout=(10, 300)) as response:
                # Without this an HTTP error page would be saved as the model.
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    # Stream in 1 MiB chunks instead of holding the whole
                    # checkpoint in memory.
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            os.replace(partial_path, filename)
        finally:
            # Only still present when the download or the rename failed.
            if os.path.exists(partial_path):
                os.remove(partial_path)
# from pose_decoder import PoseDecoder | |
# from pose_cnn import PoseCNN | |
# def disp_to_depth_no_scaling(disp): | |
# """Convert network's sigmoid output into depth prediction | |
# The formula for this conversion is given in the 'additional considerations' | |
# section of the paper. | |
# """ | |
# depth = 1 / (disp + 1e-7) | |
# return depth | |
# Function to load the model | |
def load_model(device, model_path):
    """Build the encoder and depth decoder and load their saved weights.

    Args:
        device: torch device the checkpoints are mapped to and the networks
            are moved to.
        model_path: directory containing ``encoder.pth`` and ``depth.pth``.

    Returns:
        Tuple ``(encoder, depth_decoder, feed_height, feed_width)``. Both
        networks are in eval mode; the two sizes are read from the
        ``'height'`` and ``'width'`` entries of the encoder checkpoint.

    Raises:
        KeyError: if the encoder checkpoint has no ``'height'`` or ``'width'``
            entry.
    """
    encoder_path = os.path.join(model_path, "encoder.pth")
    depth_decoder_path = os.path.join(model_path, "depth.pth")

    # NOTE(review): torch.load unpickles the file, and these checkpoints are
    # downloaded at runtime; consider weights_only=True if the deployed torch
    # version supports it.
    # presumably (num_layers=18, pretrained=False) — confirm in resnet_encoder
    encoder = ResnetEncoder(18, False)
    loaded_dict_enc = torch.load(encoder_path, map_location=device)
    feed_height = loaded_dict_enc['height']
    feed_width = loaded_dict_enc['width']

    # The checkpoint holds entries that are not network parameters (at least
    # 'height' and 'width'), so keep only keys the encoder actually has. The
    # state dict is built once here rather than once per checkpoint key.
    encoder_state = encoder.state_dict()
    filtered_dict_enc = {
        k: v for k, v in loaded_dict_enc.items() if k in encoder_state
    }
    encoder.load_state_dict(filtered_dict_enc)
    encoder.to(device)
    encoder.eval()

    depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
    loaded_dict = torch.load(depth_decoder_path, map_location=device)
    # NOTE(review): strict=False silently ignores missing/unexpected keys, so
    # a mismatched checkpoint would load without any error.
    depth_decoder.load_state_dict(loaded_dict, strict=False)
    depth_decoder.to(device)
    depth_decoder.eval()

    return encoder, depth_decoder, feed_height, feed_width
# Function to apply the model to an image | |
def predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width):
    """Run the depth network on one image and return a colour-mapped result.

    Args:
        image: image object exposing ``resize`` (the caller passes an RGB PIL
            image).
        encoder: feature network applied to the input tensor.
        depth_decoder: network mapping encoder features to a dict that holds
            the prediction under the key ``("disp", 0)``.
        device: torch device the input tensor is moved to.
        feed_height: height the image is resized to before inference.
        feed_width: width the image is resized to before inference.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with three colour channels in the
        last axis ('magma' colormap). The map keeps the resolution of the
        network output; it is not resized back to the input image size.
    """
    input_image = image.resize((feed_width, feed_height), pil.LANCZOS)
    input_image = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

    # Inference only: without no_grad the output would require grad whenever
    # the network parameters do, which makes the .numpy() call below raise and
    # needlessly keeps the autograd graph in memory.
    with torch.no_grad():
        features = encoder(input_image)
        outputs = depth_decoder(features)

    disp = outputs[("disp", 0)]
    # assumes squeeze() leaves a 2-D map (batch of one, single channel) — the
    # [:, :, :3] indexing below relies on it; TODO confirm against the decoder
    disp_np = disp.squeeze().cpu().numpy()

    # Cap the colour scale at the 95th percentile so a few extreme values do
    # not wash out the contrast of the rest of the map.
    vmax = np.percentile(disp_np, 95)
    normalizer = mpl.colors.Normalize(vmin=disp_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    # Drop the alpha channel and convert the 0..1 floats to 8-bit RGB.
    colormapped_im = (mapper.to_rgba(disp_np)[:, :, :3] * 255).astype(np.uint8)
    return colormapped_im
# Streamlit app | |
def main():
    """Render the Streamlit page: upload a PNG, predict depth, offer a download.

    The model files are fetched (if missing) on every run of the script. The
    networks are loaded and applied only after an image has been uploaded and
    the 'Predict Depth' button has been pressed.
    """
    st.title("Monodepth2 Depth Estimation")
    st.write("Upload a PNG image to get depth estimation")
    download_model_files()

    upload = st.file_uploader("Choose a PNG file", type="png")
    if upload is None:
        # Nothing to show until the user provides an image.
        return

    source_image = pil.open(upload).convert('RGB')
    st.image(source_image, caption='Uploaded Image', use_column_width=True)

    if not st.button('Predict Depth'):
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = '.'  # Update this to the path of your model
    encoder, depth_decoder, feed_height, feed_width = load_model(device, model_path)

    depth_rgb = predict_depth(
        source_image, encoder, depth_decoder, device, feed_height, feed_width
    )
    depth_image = pil.fromarray(depth_rgb)
    st.image(depth_image, caption='Predicted Depth Image', use_column_width=True)

    # Encode the result as PNG in memory so it can be offered for download.
    png_buffer = io.BytesIO()
    depth_image.save(png_buffer, format='PNG')
    png_payload = png_buffer.getvalue()

    st.download_button(
        label="Download Depth Image",
        data=png_payload,
        file_name="depth_image.png",
        mime="image/png",
    )


if __name__ == "__main__":
    main()