import io
import os
import urllib.request

import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import PIL.Image as pil
import streamlit as st
import torch
from torchvision import transforms

import networks
from layers import disp_to_depth_no_scaling

# Release assets holding the pretrained weights. Each file named in
# WEIGHT_FILES is fetched from WEIGHTS_BASE_URL into MODEL_DIR on first run
# if it is not already present there.
WEIGHTS_BASE_URL = (
    "https://github.com/meghakalia/depthEstimationColonoscopy/releases/download/0.0.1/"
)
WEIGHT_FILES = ("depth.pth", "encoder.pth")
MODEL_DIR = "."  # Update this to the path of your model


def _download_if_missing(dest_dir, filename, base_url=WEIGHTS_BASE_URL):
    """Download ``base_url + filename`` into *dest_dir* unless it already exists.

    The data is written to a ``.part`` sibling first and then moved into place
    with ``os.replace``, so an interrupted download never leaves a truncated
    file that a later run would mistake for a complete checkpoint.

    Args:
        dest_dir: Directory the file is stored in.
        filename: Name of the file, appended to *base_url* to form the URL.
        base_url: URL prefix the file is downloaded from.

    Returns:
        The local path of the file.

    Raises:
        OSError: If the download or the rename fails
            (``urllib.error.URLError`` is a subclass of ``OSError``).
    """
    dest = os.path.join(dest_dir, filename)
    if os.path.exists(dest):
        return dest
    tmp = dest + ".part"
    try:
        urllib.request.urlretrieve(base_url + filename, tmp)
        os.replace(tmp, dest)
    finally:
        # Only present if the download or rename failed part-way.
        if os.path.exists(tmp):
            os.remove(tmp)
    return dest


# Function to load the model
def load_model(device, model_path):
    """Load the ResNet-18 encoder and the depth decoder from *model_path*.

    Args:
        device: ``torch.device`` the checkpoints are mapped to and the
            networks are moved to.
        model_path: Directory containing ``encoder.pth`` and ``depth.pth``.

    Returns:
        A tuple ``(encoder, depth_decoder, feed_height, feed_width)``. Both
        networks are in eval mode. ``feed_height`` and ``feed_width`` are read
        from the encoder checkpoint's ``'height'`` and ``'width'`` entries.
    """
    encoder_path = os.path.join(model_path, "encoder.pth")
    depth_decoder_path = os.path.join(model_path, "depth.pth")

    encoder = networks.ResnetEncoder(18, False)
    loaded_dict_enc = torch.load(encoder_path, map_location=device)

    # The encoder checkpoint carries extra entries (e.g. the input resolution)
    # besides the weights. Read the resolution, then keep only the keys that
    # belong to the encoder's own state dict.
    feed_height = loaded_dict_enc['height']
    feed_width = loaded_dict_enc['width']
    encoder_keys = encoder.state_dict()  # build once, not once per key
    filtered_dict_enc = {
        k: v for k, v in loaded_dict_enc.items() if k in encoder_keys
    }
    encoder.load_state_dict(filtered_dict_enc)
    encoder.to(device)
    encoder.eval()

    depth_decoder = networks.DepthDecoder(
        num_ch_enc=encoder.num_ch_enc, scales=range(4)
    )
    loaded_dict = torch.load(depth_decoder_path, map_location=device)
    # strict=False silently ignores missing/unexpected keys.
    # NOTE(review): confirm the checkpoint is expected not to match exactly.
    depth_decoder.load_state_dict(loaded_dict, strict=False)
    depth_decoder.to(device)
    depth_decoder.eval()

    return encoder, depth_decoder, feed_height, feed_width


# Function to apply the model to an image
def predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width):
    """Predict a disparity map for *image* and render it with the 'magma' colormap.

    Args:
        image: PIL image. It is converted to RGB and resized to
            ``(feed_width, feed_height)`` with Lanczos resampling.
        encoder: Encoder network returned by ``load_model``.
        depth_decoder: Decoder network returned by ``load_model``.
        device: ``torch.device`` the input tensor is moved to.
        feed_height: Network input height in pixels.
        feed_width: Network input width in pixels.

    Returns:
        A ``uint8`` RGB array of shape ``(H, W, 3)``, at the resolution of the
        decoder's scale-0 disparity output. Colours are normalised from the
        minimum disparity up to its 95th percentile.
    """
    input_image = image.convert("RGB").resize((feed_width, feed_height), pil.LANCZOS)
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

    # Inference only. Without no_grad the outputs would require grad (the
    # parameters do), and .numpy() below would raise a RuntimeError.
    with torch.no_grad():
        features = encoder(input_tensor)
        outputs = depth_decoder(features)

    disp = outputs[("disp", 0)]
    # presumably (1, 1, H, W) -> (H, W) after squeeze; to_rgba below needs 2-D.
    disp_np = disp.squeeze().cpu().numpy()

    # Clip the colour range at the 95th percentile so a few very near pixels
    # do not wash out the rest of the map.
    vmax = np.percentile(disp_np, 95)
    normalizer = mpl.colors.Normalize(vmin=disp_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    colormapped_im = (mapper.to_rgba(disp_np)[:, :, :3] * 255).astype(np.uint8)
    return colormapped_im


# Streamlit app
def main():
    """Streamlit entry point: fetch the weights, take a PNG upload, show its depth map."""
    st.title("Monodepth2 Depth Estimation")
    st.write("Upload a PNG image to get depth estimation")

    # Fetch the pretrained weights on first launch. Uses the stdlib, so no
    # external wget binary is needed, and failures are surfaced, not ignored.
    try:
        for filename in WEIGHT_FILES:
            _download_if_missing(MODEL_DIR, filename)
    except OSError as err:
        st.error(f"Could not download model weights: {err}")
        st.stop()

    uploaded_file = st.file_uploader("Choose a PNG file", type="png")
    if uploaded_file is None:
        return

    image = pil.open(uploaded_file).convert('RGB')
    st.image(image, caption='Uploaded Image', use_column_width=True)

    if not st.button('Predict Depth'):
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: the model is reloaded on every click; consider caching it with
    # Streamlit's resource cache if the installed version supports it.
    encoder, depth_decoder, feed_height, feed_width = load_model(device, MODEL_DIR)
    colormapped_im = predict_depth(
        image, encoder, depth_decoder, device, feed_height, feed_width
    )
    depth_image = pil.fromarray(colormapped_im)
    st.image(depth_image, caption='Predicted Depth Image', use_column_width=True)

    # Convert depth image to bytes
    buffer = io.BytesIO()
    depth_image.save(buffer, format='PNG')
    img_bytes = buffer.getvalue()

    # Provide download link
    st.download_button(
        label="Download Depth Image",
        data=img_bytes,
        file_name="depth_image.png",
        mime="image/png",
    )


if __name__ == "__main__":
    main()