import io
import os
import subprocess

import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import PIL.Image as pil
import requests
import streamlit as st
import torch
from torchvision import transforms

from depth_decoder import DepthDecoder
from resnet_encoder import ResnetEncoder


def download_model_files(dest_dir=".", timeout=60):
    """Download the pretrained depth/encoder weights if they are not present.

    Each of ``depth.pth`` and ``encoder.pth`` is fetched from the project's
    GitHub release only when it does not already exist in *dest_dir*.

    Args:
        dest_dir: Directory the weight files are stored in (default: the
            current working directory, which is where ``main`` loads from).
        timeout: Seconds passed to ``requests`` as the connect/read timeout.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On other network failures or timeouts.
    """
    release_url = (
        "https://github.com/meghakalia/depthEstimationColonoscopy"
        "/releases/download/0.0.1"
    )
    for file_name in ("depth.pth", "encoder.pth"):
        target = os.path.join(dest_dir, file_name)
        if os.path.exists(target):
            continue

        partial = target + ".part"
        with requests.get(
            f"{release_url}/{file_name}", stream=True, timeout=timeout
        ) as response:
            # Fail loudly instead of saving an HTML error page as "weights":
            # once a file exists under the final name it is never re-fetched.
            response.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        # Move into place only after a complete download, so an interrupted
        # transfer never leaves a truncated file under the final name.
        os.replace(partial, target)


# from pose_decoder import PoseDecoder
# from pose_cnn import PoseCNN

# def disp_to_depth_no_scaling(disp):
#     """Convert network's sigmoid output into depth prediction
#     The formula for this conversion is given in the 'additional considerations'
#     section of the paper.
#     """
#     depth = 1 / (disp + 1e-7)
#     return depth


def load_model(device, model_path):
    """Build the encoder and depth decoder and load their checkpoints.

    Args:
        device: ``torch.device`` the checkpoints are mapped onto and the
            models are moved to.
        model_path: Directory containing ``encoder.pth`` and ``depth.pth``.

    Returns:
        Tuple ``(encoder, depth_decoder, feed_height, feed_width)``. Both models
        are in eval mode; ``feed_height``/``feed_width`` are the input size
        stored under the ``'height'``/``'width'`` keys of the encoder checkpoint.
    """
    encoder_path = os.path.join(model_path, "encoder.pth")
    depth_decoder_path = os.path.join(model_path, "depth.pth")

    # NOTE(review): (18, False) presumably selects ResNet-18 without ImageNet
    # initialisation; the real weights come from the checkpoint below.
    encoder = ResnetEncoder(18, False)
    # NOTE(security): torch.load unpickles arbitrary objects and these files
    # are downloaded from the network -- only point model_path at trusted files.
    loaded_dict_enc = torch.load(encoder_path, map_location=device)
    feed_height = loaded_dict_enc['height']
    feed_width = loaded_dict_enc['width']
    # The checkpoint also carries non-parameter entries (e.g. height/width);
    # keep only keys the encoder actually has.
    filtered_dict_enc = {
        k: v for k, v in loaded_dict_enc.items() if k in encoder.state_dict()
    }
    encoder.load_state_dict(filtered_dict_enc)
    encoder.to(device)
    encoder.eval()

    depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
    loaded_dict = torch.load(depth_decoder_path, map_location=device)
    # strict=False: missing/unexpected keys are silently ignored, so a
    # mismatched checkpoint will not raise here.
    depth_decoder.load_state_dict(loaded_dict, strict=False)
    depth_decoder.to(device)
    depth_decoder.eval()

    return encoder, depth_decoder, feed_height, feed_width


def predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width):
    """Run the network on a PIL image and return a colour-mapped disparity map.

    Args:
        image: PIL image (the app passes an RGB image); it is resized to the
            network input size with Lanczos resampling.
        encoder: Encoder returned by ``load_model``.
        depth_decoder: Depth decoder returned by ``load_model``.
        device: ``torch.device`` the models live on.
        feed_height: Network input height.
        feed_width: Network input width.

    Returns:
        ``uint8`` array of shape ``(H, W, 3)``: the scale-0 disparity rendered
        with the ``'magma'`` colormap. It is not resized back to the input
        image's size (presumably H, W == feed_height, feed_width -- depends on
        DepthDecoder's scale-0 output).
    """
    input_image = image.resize((feed_width, feed_height), pil.LANCZOS)
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

    # Inference only. Without no_grad the output keeps autograd history and
    # the .numpy() call below raises when the caller has not disabled grads.
    with torch.no_grad():
        features = encoder(input_tensor)
        outputs = depth_decoder(features)
    disp_np = outputs[("disp", 0)].squeeze().cpu().numpy()

    # Saturate the top 5% of disparities so a few very close pixels do not
    # compress the rest of the colour range.
    vmax = np.percentile(disp_np, 95)
    normalizer = mpl.colors.Normalize(vmin=disp_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    # to_rgba returns RGBA floats in [0, 1]; drop alpha and convert to 8-bit.
    colormapped_im = (mapper.to_rgba(disp_np)[:, :, :3] * 255).astype(np.uint8)
    return colormapped_im


def main():
    """Streamlit entry point: upload a PNG, predict its depth, offer a download."""
    st.title("Self & Semi-Supervised Methods for Depth & Pose Estimation")
    st.write("Upload a PNG image to get depth estimation")

    try:
        download_model_files()
    except requests.RequestException as err:
        # Show network/HTTP failures in the UI instead of a raw traceback.
        st.error(f"Could not download the model weights: {err}")
        st.stop()

    uploaded_file = st.file_uploader("Choose a PNG file", type="png")
    if uploaded_file is None:
        return

    try:
        image = pil.open(uploaded_file).convert('RGB')
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError) for undecodable data;
        # the uploader only filters by extension, not by content.
        st.error("The uploaded file could not be read as an image.")
        return
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if not st.button('Predict Depth'):
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = '.'  # Update this to the path of your model
    # NOTE(review): the models are reloaded on every click; caching them
    # across reruns (e.g. st.cache_resource, if available) would avoid that.
    encoder, depth_decoder, feed_height, feed_width = load_model(device, model_path)
    colormapped_im = predict_depth(
        image, encoder, depth_decoder, device, feed_height, feed_width
    )

    depth_image = pil.fromarray(colormapped_im)
    st.image(depth_image, caption='Predicted Depth Image', use_column_width=True)

    # Encode the depth image as PNG bytes in memory for the download button.
    buffer = io.BytesIO()
    depth_image.save(buffer, format='PNG')
    st.download_button(
        label="Download Depth Image",
        data=buffer.getvalue(),
        file_name="depth_image.png",
        mime="image/png",
    )


if __name__ == "__main__":
    main()