"""Streamlit front end for Monodepth2-style monocular depth estimation.

On start-up the app downloads any missing pretrained weights. It then lets the
user upload a PNG and renders a colour-mapped disparity map predicted by the
network. The result can also be downloaded as a PNG.
"""

import io
import os
import subprocess

import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import PIL.Image as pil
import streamlit as st
import torch
from torchvision import transforms

from depth_decoder import DepthDecoder
from resnet_encoder import ResnetEncoder

# Directory the weights are downloaded into and loaded from.
MODEL_PATH = "."
# Release that hosts the pretrained weight files.
WEIGHTS_URL = (
    "https://github.com/meghakalia/depthEstimationColonoscopy/"
    "releases/download/0.0.1/"
)
# Weight files required by load_model(), fetched in this order if missing.
WEIGHT_FILES = ("depth.pth", "encoder.pth")


def load_model(device, model_path):
    """Load the ResNet-18 encoder and the depth decoder from ``model_path``.

    Args:
        device: ``torch.device`` used as ``map_location`` and that both
            networks are moved to.
        model_path: Directory containing ``encoder.pth`` and ``depth.pth``.

    Returns:
        Tuple ``(encoder, depth_decoder, feed_height, feed_width)``. Both
        networks are in eval mode. ``feed_height`` and ``feed_width`` are read
        from the ``'height'`` and ``'width'`` entries of the encoder
        checkpoint.

    Raises:
        KeyError: if the encoder checkpoint lacks ``'height'`` or ``'width'``.
    """
    encoder_path = os.path.join(model_path, "encoder.pth")
    depth_decoder_path = os.path.join(model_path, "depth.pth")

    # 18 = ResNet depth. The second argument presumably disables pretrained
    # initialisation, since the weights come from the checkpoint below
    # (confirm against ResnetEncoder).
    encoder = ResnetEncoder(18, False)
    loaded_dict_enc = torch.load(encoder_path, map_location=device)
    feed_height = loaded_dict_enc["height"]
    feed_width = loaded_dict_enc["width"]
    # The checkpoint holds extra non-parameter entries (e.g. height/width),
    # so keep only keys the encoder actually has. Build the key set once.
    encoder_keys = encoder.state_dict().keys()
    filtered_dict_enc = {
        k: v for k, v in loaded_dict_enc.items() if k in encoder_keys
    }
    encoder.load_state_dict(filtered_dict_enc)
    encoder.to(device)
    encoder.eval()

    depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
    loaded_dict = torch.load(depth_decoder_path, map_location=device)
    # strict=False tolerates missing/unexpected keys in the decoder checkpoint.
    depth_decoder.load_state_dict(loaded_dict, strict=False)
    depth_decoder.to(device)
    depth_decoder.eval()

    return encoder, depth_decoder, feed_height, feed_width


def predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width):
    """Run the network on ``image`` and return a colour-mapped disparity map.

    Args:
        image: PIL image. It is resized to ``(feed_width, feed_height)`` with
            Lanczos resampling before inference.
        encoder: Encoder returned by :func:`load_model`.
        depth_decoder: Decoder returned by :func:`load_model`.
        device: ``torch.device`` the input tensor is moved to.
        feed_height: Network input height in pixels.
        feed_width: Network input width in pixels.

    Returns:
        ``uint8`` RGB array produced from the decoder's ``("disp", 0)``
        output with the ``magma`` colormap. It is presumably
        ``(feed_height, feed_width, 3)``; confirm the decoder's scale-0
        resolution.
    """
    input_image = image.resize((feed_width, feed_height), pil.LANCZOS)
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

    # Inference only. Without no_grad() the output tracks gradients (model
    # parameters require grad even in eval mode), so .numpy() below would
    # raise, and an autograd graph would be built for nothing.
    with torch.no_grad():
        features = encoder(input_tensor)
        outputs = depth_decoder(features)
    disp_np = outputs[("disp", 0)].squeeze().cpu().numpy()

    # Cap the colour range at the 95th percentile so a few extreme values
    # do not compress the rest of the map into a narrow band of colours.
    vmax = np.percentile(disp_np, 95)
    normalizer = mpl.colors.Normalize(vmin=disp_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap="magma")
    # Drop the alpha channel and scale [0, 1] floats to 8-bit.
    return (mapper.to_rgba(disp_np)[:, :, :3] * 255).astype(np.uint8)


def _ensure_weights(model_path):
    """Download any missing entry of WEIGHT_FILES into ``model_path`` via wget.

    Each file is fetched to a ``.part`` name and renamed only on success. A
    failed or interrupted download therefore never leaves a truncated file
    that would later pass the existence check.

    Returns:
        True if every weight file is present afterwards. False if a download
        failed; an error is shown in the UI in that case.
    """
    for name in WEIGHT_FILES:
        dest = os.path.join(model_path, name)
        if os.path.exists(dest):
            continue
        partial = dest + ".part"
        try:
            subprocess.run(["wget", "-O", partial, WEIGHTS_URL + name], check=True)
        except (OSError, subprocess.CalledProcessError) as err:
            # OSError covers a missing wget binary (FileNotFoundError).
            if os.path.exists(partial):
                os.remove(partial)
            st.error(f"Could not download {name}: {err}")
            return False
        os.replace(partial, dest)
    return True


def main():
    """Render the app: upload a PNG, predict its depth, offer a PNG download."""
    st.title("Monodepth2 Depth Estimation")
    st.write("Upload a PNG image to get depth estimation")

    if not _ensure_weights(MODEL_PATH):
        return

    uploaded_file = st.file_uploader("Choose a PNG file", type="png")
    if uploaded_file is None:
        return

    image = pil.open(uploaded_file).convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if not st.button("Predict Depth"):
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder, depth_decoder, feed_height, feed_width = load_model(device, MODEL_PATH)
    colormapped_im = predict_depth(
        image, encoder, depth_decoder, device, feed_height, feed_width
    )
    depth_image = pil.fromarray(colormapped_im)
    st.image(depth_image, caption="Predicted Depth Image", use_column_width=True)

    # Encode the result as PNG bytes for the download button.
    buffer = io.BytesIO()
    depth_image.save(buffer, format="PNG")
    st.download_button(
        label="Download Depth Image",
        data=buffer.getvalue(),
        file_name="depth_image.png",
        mime="image/png",
    )


if __name__ == "__main__":
    main()