File size: 4,456 Bytes
e86792f
 
 
 
 
 
 
 
 
d4950f4
73180f7
 
33485a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73180f7
 
e86792f
d4950f4
 
 
 
 
 
 
 
e86792f
 
 
 
 
 
73180f7
e86792f
 
 
 
 
 
 
 
73180f7
e86792f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86948bc
e86792f
 
33485a1
 
e86792f
 
 
 
 
 
 
 
 
 
 
9ca77c2
 
2f16ce2
 
 
e86792f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import streamlit as st
import numpy as np
import PIL.Image as pil
import torch
import os
import io
from torchvision import transforms
import matplotlib as mpl
import matplotlib.cm as cm

from resnet_encoder import ResnetEncoder
from depth_decoder import DepthDecoder

import os
import subprocess
import requests

def download_model_files(dest_dir=".", timeout=60):
    """Download the pretrained depth-decoder and encoder checkpoints if missing.

    ``depth.pth`` and ``encoder.pth`` are fetched from the project's GitHub
    release into *dest_dir* (default: the current working directory, which is
    the ``model_path`` that ``main`` hands to ``load_model``). A file that
    already exists is left untouched; ``main`` calls this on every run, so the
    existence check is what prevents repeated downloads.

    Args:
        dest_dir: directory the ``.pth`` files are written to (created if needed).
        timeout: per-request timeout in seconds passed to ``requests.get``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    release_url = (
        "https://github.com/meghakalia/depthEstimationColonoscopy"
        "/releases/download/0.0.1"
    )
    os.makedirs(dest_dir, exist_ok=True)
    # Same order as before: depth decoder first, then the encoder.
    for filename in ("depth.pth", "encoder.pth"):
        target = os.path.join(dest_dir, filename)
        if os.path.exists(target):
            continue
        _download_to(f"{release_url}/{filename}", target, timeout)


def _download_to(url, target, timeout):
    """Stream *url* into *target*, creating *target* only if the download succeeds.

    The body is written to a temporary file in the same directory and then
    moved into place with ``os.replace``. An HTTP error, timeout or interruption
    therefore never leaves a truncated file or an HTML error page under the
    final name. ``download_model_files`` would otherwise accept such a file on
    every later run because it only checks for existence.
    """
    import tempfile

    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this, a 404/5xx response body would be saved as the checkpoint.
        response.raise_for_status()
        fd, tmp_path = tempfile.mkstemp(
            dir=os.path.dirname(target) or ".", suffix=".part"
        )
        try:
            with os.fdopen(fd, "wb") as out:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    out.write(chunk)
            os.replace(tmp_path, target)
        except BaseException:
            # Remove the partial temp file, then let the original error propagate.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise


# from pose_decoder import PoseDecoder
# from pose_cnn import PoseCNN

# def disp_to_depth_no_scaling(disp):
#     """Convert network's sigmoid output into depth prediction
#     The formula for this conversion is given in the 'additional considerations'
#     section of the paper.
#     """
#     depth = 1 / (disp + 1e-7)
#     return depth
    
# Function to load the model
def load_model(device, model_path):
    """Build the ResNet-18 encoder and the depth decoder and load their weights.

    Args:
        device: ``torch.device`` that both checkpoints are mapped onto and both
            networks are moved to.
        model_path: directory containing ``encoder.pth`` and ``depth.pth``.

    Returns:
        Tuple ``(encoder, depth_decoder, feed_height, feed_width)``. Both
        networks are in eval mode. The feed size comes from the encoder
        checkpoint's ``'height'`` and ``'width'`` entries.
    """
    encoder_file, decoder_file = (
        os.path.join(model_path, fname) for fname in ("encoder.pth", "depth.pth")
    )

    # --- encoder ---------------------------------------------------------
    encoder = ResnetEncoder(18, False)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only point model_path at trusted checkpoints.
    encoder_ckpt = torch.load(encoder_file, map_location=device)
    feed_height, feed_width = encoder_ckpt['height'], encoder_ckpt['width']
    # The checkpoint presumably carries non-parameter entries (such as
    # 'height'/'width'), so keep only keys the encoder's state dict knows about.
    expected_keys = encoder.state_dict()
    encoder.load_state_dict(
        {name: value for name, value in encoder_ckpt.items() if name in expected_keys}
    )
    encoder.to(device)
    encoder.eval()

    # --- depth decoder ---------------------------------------------------
    depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
    decoder_ckpt = torch.load(decoder_file, map_location=device)
    # strict=False: missing or unexpected keys are not reported as errors.
    depth_decoder.load_state_dict(decoder_ckpt, strict=False)
    depth_decoder.to(device)
    depth_decoder.eval()

    return encoder, depth_decoder, feed_height, feed_width

# Function to apply the model to an image
def predict_depth(image, encoder, depth_decoder, device, feed_height, feed_width,
                  *, resize_to_input=False):
    """Run the depth network on *image* and return a colour-mapped disparity map.

    Processing steps:
    - The image is converted to RGB and resized (Lanczos) to
      ``(feed_width, feed_height)``.
    - It is turned into a batch-of-one tensor and passed through ``encoder``
      and then ``depth_decoder``.
    - The ``("disp", 0)`` output is rendered with the ``magma`` colormap. The
      colour range spans the minimum disparity to its 95th percentile, so the
      top 5 % of values saturate.

    Autograd is disabled internally, so callers need not wrap this call in
    ``torch.no_grad()``.

    Args:
        image: PIL image; converted to RGB before inference.
        encoder, depth_decoder: networks as returned by ``load_model``.
        device: ``torch.device`` the input tensor is moved to.
        feed_height, feed_width: network input size in pixels.
        resize_to_input: if True, bilinearly upsample the disparity to the
            original ``image.size`` before colour-mapping. By default the map is
            returned at the decoder's output resolution.

    Returns:
        ``np.uint8`` RGB array of shape ``(H, W, 3)``. The ``[:, :, :3]``
        indexing requires the squeezed disparity to be 2-D.
    """
    original_width, original_height = image.size
    rgb_image = image.convert("RGB")
    input_image = rgb_image.resize((feed_width, feed_height), pil.LANCZOS)
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

    # Gradients are never needed here, and .numpy() below raises on a tensor
    # that requires grad, so do not rely on the caller to disable autograd.
    with torch.no_grad():
        outputs = depth_decoder(encoder(input_tensor))
        disp = outputs[("disp", 0)]
        if resize_to_input:
            # assumes disp is (1, 1, h, w), as bilinear interpolate needs 4-D input
            disp = torch.nn.functional.interpolate(
                disp, (original_height, original_width),
                mode="bilinear", align_corners=False,
            )

    disp_np = disp.squeeze().cpu().numpy()
    vmax = np.percentile(disp_np, 95)
    normalizer = mpl.colors.Normalize(vmin=disp_np.min(), vmax=vmax)
    mapper = cm.ScalarMappable(norm=normalizer, cmap='magma')
    return (mapper.to_rgba(disp_np)[:, :, :3] * 255).astype(np.uint8)

# Streamlit app
def main():
    """Render the Streamlit depth-estimation page.

    The page flow is:
    - Show the title and instructions.
    - Make sure the checkpoints are present via ``download_model_files``.
    - Offer a PNG uploader and display the uploaded image (converted to RGB).
    - When "Predict Depth" is pressed, load the models, run ``predict_depth``,
      show the colour-mapped result and offer it as a PNG download.
    """
    st.title("Self & Semi-Supervised Methods for Depth & Pose Estimation")
    st.write("Upload a PNG image to get depth estimation")

    download_model_files()

    png_upload = st.file_uploader("Choose a PNG file", type="png")
    if png_upload is None:
        return

    rgb_image = pil.open(png_upload).convert('RGB')
    st.image(rgb_image, caption='Uploaded Image', use_column_width=True)

    if not st.button('Predict Depth'):
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weights_dir = '.'  # directory holding encoder.pth / depth.pth; change if they live elsewhere
    encoder, depth_decoder, feed_height, feed_width = load_model(device, weights_dir)
    encoder.eval()
    depth_decoder.eval()

    with torch.no_grad():
        depth_image = pil.fromarray(
            predict_depth(rgb_image, encoder, depth_decoder, device, feed_height, feed_width)
        )

    st.image(depth_image, caption='Predicted Depth Image', use_column_width=True)

    # Serialise the result to PNG bytes for the download button.
    png_buffer = io.BytesIO()
    depth_image.save(png_buffer, format='PNG')
    st.download_button(
        label="Download Depth Image",
        data=png_buffer.getvalue(),
        file_name="depth_image.png",
        mime="image/png",
    )

# Script entry point: build the page when this module is executed directly
# rather than imported.
if __name__ == "__main__":
    main()