vijul.shah committed
Commit 57d7ed3
1 Parent(s): c34dd19
Files changed (1)
  1. app.py +229 -0
app.py ADDED
@@ -0,0 +1,229 @@
+ # taken from: https://huggingface.co/spaces/frgfm/torch-cam/blob/main/app.py
+
+ # streamlit run app.py
+ from io import BytesIO
+ import os
+ import sys
+ import matplotlib.pyplot as plt
+ import requests
+ import streamlit as st
+ import torch
+ from PIL import Image
+ from torchvision import models
+ from torchvision.transforms.functional import normalize, resize, to_pil_image, to_tensor
+ from torchvision import transforms
+
+ from torchcam.methods import CAM
+ from torchcam import methods as torchcam_methods
+ from torchcam.utils import overlay_mask
+ import os.path as osp
+
+ root_path = osp.abspath(osp.join(__file__, osp.pardir))
+ sys.path.append(root_path)
+
+ from utils import get_model
+ from registry_utils import import_registered_modules
+
+ import_registered_modules()
+ # from torchcam.methods._utils import locate_candidate_layer
+
+ CAM_METHODS = [
+     "CAM",
+     # "GradCAM",
+     # "GradCAMpp",
+     # "SmoothGradCAMpp",
+     # "ScoreCAM",
+     # "SSCAM",
+     # "ISCAM",
+     # "XGradCAM",
+     # "LayerCAM",
+ ]
+ TV_MODELS = [
+     "resnet18",
+     # "resnet50",
+ ]
+ SR_METHODS = ["GFPGAN", "RealESRGAN", "SRResNet", "CodeFormer", "HAT"]
+ UPSCALE = ["2", "3", "4"]
+ LABEL_MAP = [
+     "left_eye",
+     "right_eye",
+ ]
+
+
+ @torch.no_grad()
+ def _load_model(model_configs, device="cpu"):
+     model_path = os.path.join(root_path, model_configs["model_path"])
+     model_configs.pop("model_path")
+     model_dict = torch.load(model_path, map_location=device)
+     model = get_model(model_configs=model_configs)
+     model.load_state_dict(model_dict)
+     model = model.to(device)
+     model = model.eval()
+     return model
+
+
+ def main():
+     # Wide mode
+     st.set_page_config(page_title="Pupil Diameter Estimator", layout="wide")
+
+     # Designing the interface
+     st.title("EyeDentify Playground")
+     # For newline
+     st.write("\n")
+     # Set the columns
+     cols = st.columns((1, 1))
+     # cols = st.columns((1, 1, 1))
+     cols[0].header("Input image")
+     # cols[1].header("Raw CAM")
+     cols[-1].header("Prediction")
+
+     # Sidebar
+     # File selection
+     st.sidebar.title("Input selection")
+     # Disabling warning
+     st.set_option("deprecation.showfileUploaderEncoding", False)
+     # Choose your own image
+     uploaded_file = st.sidebar.file_uploader(
+         "Upload files", type=["png", "jpeg", "jpg"]
+     )
+     if uploaded_file is not None:
+         img = Image.open(BytesIO(uploaded_file.read()), mode="r").convert("RGB")
+
+         cols[0].image(img, use_column_width=True)
+
+     # Model selection
+     st.sidebar.title("Setup")
+     tv_model = st.sidebar.selectbox(
+         "Classification model",
+         TV_MODELS,
+         help="Supported models from Torchvision",
+     )
+
+     # class_choices = [
+     #     f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)
+     # ]
+     # class_selection = st.sidebar.selectbox(
+     #     "Class selection", ["Predicted class (argmax)", *class_choices]
+     # )
+
+     img_configs = {"img_size": [32, 64], "means": None, "stds": None}
+     # For newline
+     st.sidebar.write("\n")
+
+     if st.sidebar.button("Compute CAM"):
+         if uploaded_file is None:
+             st.sidebar.error("Please upload an image first")
+
+         else:
+             with st.spinner("Analyzing..."):
+
+                 preprocess_steps = [transforms.ToTensor()]
+
+                 image_size = img_configs["img_size"]
+                 if image_size is not None:
+                     preprocess_steps.append(
+                         transforms.Resize(
+                             [image_size[0], image_size[-1]],
+                             interpolation=transforms.InterpolationMode.BICUBIC,
+                             antialias=True,
+                         )
+                     )
+
+                 means = img_configs["means"]
+                 stds = img_configs["stds"]
+                 if means is not None and stds is not None:
+                     preprocess_steps.append(transforms.Normalize(means, stds))
+
+                 preprocess_function = transforms.Compose(preprocess_steps)
+                 input_img = preprocess_function(img)
+                 input_img = input_img.unsqueeze(0).to(device="cpu")
+
+                 model_configs = {
+                     "model_path": root_path
+                     + "/pre_trained_models/ResNet18/left_eye.pt",
+                     "registered_model_name": "ResNet18",
+                     "num_classes": 1,
+                 }
+                 registered_model_name = model_configs["registered_model_name"]
+                 # default_layer = ""
+                 if tv_model is not None:
+                     with st.spinner("Loading model..."):
+                         model = _load_model(model_configs)
+
+                     if torch.cuda.is_available():
+                         model = model.cuda()
+
+                     if registered_model_name == "ResNet18":
+                         target_layer = model.resnet.layer4[-1].conv2
+                     elif registered_model_name == "ResNet50":
+                         target_layer = model.resnet.layer4[-1].conv3
+                     else:
+                         raise Exception(
+                             f"No target layer available for selected model: {registered_model_name}"
+                         )
+
+                     # target_layer = st.sidebar.text_input(
+                     #     "Target layer",
+                     #     default_layer,
+                     #     help='If you want to target several layers, add a "+" separator (e.g. "layer3+layer4")',
+                     # )
+                     cam_method = "CAM"
+                     # cam_method = st.sidebar.selectbox(
+                     #     "CAM method",
+                     #     CAM_METHODS,
+                     #     help="The way your class activation map will be computed",
+                     # )
+                     if cam_method is not None:
+                         # cam_extractor = methods.__dict__[cam_method](
+                         #     model,
+                         #     target_layer=(
+                         #         [s.strip() for s in target_layer.split("+")]
+                         #         if len(target_layer) > 0
+                         #         else None
+                         #     ),
+                         # )
+                         cam_extractor = torchcam_methods.__dict__[cam_method](
+                             model,
+                             target_layer=target_layer,
+                             fc_layer=model.resnet.fc,
+                             input_shape=(3, 32, 64),
+                         )
+                     # with torch.no_grad():
+                     #     if input_mask is not None:
+                     #         out = self.model(input_img, input_mask)
+                     #     else:
+                     #         out = self.model(input_img)
+                     # activation_map = cam_extractor(class_idx=target_class)
+
+                     # Forward the image to the model
+                     out = model(input_img)
+                     print("out = ", out)
+
+                     # Select the target class
+                     # if class_selection == "Predicted class (argmax)":
+                     #     class_idx = out.squeeze(0).argmax().item()
+                     # else:
+                     #     class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
+
+                     # Retrieve the CAM
+                     # act_maps = cam_extractor(class_idx=target_class)
+                     act_maps = cam_extractor(0, out)
+                     # Fuse the CAMs if there are several
+                     activation_map = (
+                         act_maps[0]
+                         if len(act_maps) == 1
+                         else cam_extractor.fuse_cams(act_maps)
+                     )
+
+                     # Overlaid CAM
+                     fig, ax = plt.subplots()
+                     result = overlay_mask(
+                         img, to_pil_image(activation_map, mode="F"), alpha=0.5
+                     )
+                     ax.imshow(result)
+                     ax.axis("off")
+                     cols[-1].pyplot(fig)
+
+
+ if __name__ == "__main__":
+     main()
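
For reference, the core torchcam flow that this app.py wraps (build a CAM extractor on the model, run a forward pass so the hooks record activations, then turn the returned map into an overlay) can be reproduced outside Streamlit. The sketch below is not part of the commit: the stock ImageNet resnet18, the "layer4"/"fc" layer names, and the sample.png path are illustrative stand-ins for the repository's registered ResNet18 eye model and the uploaded image.

from PIL import Image
import torch
from torchvision import transforms
from torchvision.models import resnet18
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import CAM
from torchcam.utils import overlay_mask

# Stand-in model: a stock ImageNet resnet18, NOT the repository's eye model.
model = resnet18(weights="IMAGENET1K_V1").eval()

# Hook the last conv block and reuse the classifier weights, as CAM requires.
cam_extractor = CAM(model, target_layer="layer4", fc_layer="fc")

# "sample.png" is a placeholder path for any local RGB image.
img = Image.open("sample.png").convert("RGB")
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(img).unsqueeze(0)

out = model(input_tensor)                   # forward pass; hooks record the activations
class_idx = out.squeeze(0).argmax().item()  # app.py hard-codes index 0 since num_classes=1
act_maps = cam_extractor(class_idx, out)    # one map per hooked layer
result = overlay_mask(img, to_pil_image(act_maps[0].squeeze(0), mode="F"), alpha=0.5)
result.save("cam_overlay.png")

When several target layers are hooked, cam_extractor.fuse_cams merges the per-layer maps, which is what the len(act_maps) == 1 branch in app.py handles.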