feat: workbench page
- app.py +268 -95
- closest_sample.py +56 -3
- explanations.py +47 -20
- inference_resnet.py +1 -1
app.py
CHANGED
@@ -21,7 +21,8 @@ from inference_resnet import get_triplet_model
 from inference_beit import get_triplet_model_beit
 import pathlib
 import tensorflow as tf
-from closest_sample import get_images
+from closest_sample import get_images,get_diagram
+
 
 if not os.path.exists('images'):
     REPO_ID='Serrelab/image_examples_gradio'
@@ -35,6 +36,57 @@ if not os.path.exists('dataset'):
         print("warning! A read token in env variables is needed for authentication.")
     snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset')
 
+HEADER = '''
+<h2><b>Official Gradio Demo</b></h2><h2><a href='https://huggingface.co/spaces/Serrelab/fossil_app' target='_blank'><b>Identifying Florissant Leaf Fossils to Family using Deep Neural Networks </b></a></h2>
+Code: <a href='https://github.com/orgs/serre-lab/projects/2' target='_blank'>GitHub</a>. Paper: <a href='' target='_blank'>ArXiv</a>.
+
+
+'''
+
+"""
+**Fossil** a brief intro to the project.
+# ❗️❗️❗️**Important Notes:**
+# - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users .
+# - some notes to users some notes to users some notes to users some notes to users some notes to users some notes to users.
+
+"""
+
+USER_GUIDE = """
+<div style='background-color: #f0f0f0; padding: 20px; border-radius: 10px;'>
+    <h2>❗️ User Guide</h2>
+    <p>Welcome to the interactive fossil exploration tool. Here's how to get started:</p>
+    <ul>
+        <li><strong>Upload an Image:</strong> Drag and drop or choose from given samples to upload images of fossils.</li>
+        <li><strong>Process Image:</strong> After uploading, click the 'Process Image' button to analyze the image.</li>
+        <li><strong>Explore Results:</strong> Switch to the 'Workbench' tab to check out detailed analysis and results.</li>
+    </ul>
+    <h3>Tips</h3>
+    <ul>
+        <li>Zoom into images on the workbench for finer details.</li>
+        <li>Use the examples below as references for what types of images to upload.</li>
+    </ul>
+    <p>Enjoy exploring! 🌟</p>
+</div>
+"""
+
+TIPS = """
+## Tips
+- Zoom into images on the workbench for finer details.
+- Use the examples below as references for what types of images to upload.
+
+Enjoy exploring!
+"""
+CITATION = '''
+📧 **Contact** <br>
+If you have any questions, feel free to contact us at <b>ivan_felipe_rodriguez@brown.edu</b>.
+'''
+"""
+📝 **Citation**
+cite using this bibtex:...
+```
+```
+📋 **License**
+"""
 def get_model(model_name):
 
 
@@ -61,6 +113,13 @@ def get_model(model_name):
                                        embedding_depth = 2,
                                        n_classes = n_classes)
         model.load_weights('model_classification/fossil-142.h5')
+    elif model_name == 'Fossils new':
+        n_classes = 142
+        model = get_triplet_model_beit(input_shape = (384, 384, 3),
+                                       embedding_units = 256,
+                                       embedding_depth = 2,
+                                       n_classes = n_classes)
+        model.load_weights('model_classification/fossil-new.h5')
     else:
         raise ValueError(f"Model name '{model_name}' is not recognized")
     return model,n_classes
@@ -82,7 +141,12 @@ def classify_image(input_image, model_name):
         model, n_classes= get_model(model_name)
        result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
         return result
-    elif model_name=='Fossils 142':
+    elif 'Fossils 142' ==model_name:
+        from inference_beit import inference_resnet_finer_beit
+        model,n_classes = get_model(model_name)
+        result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
+        return result
+    elif 'Fossils new' ==model_name:
         from inference_beit import inference_resnet_finer_beit
         model,n_classes = get_model(model_name)
         result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes)
@@ -100,7 +164,12 @@ def get_embeddings(input_image,model_name):
         model, n_classes= get_model(model_name)
         result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
         return result
-    elif model_name=='Fossils 142':
+    elif 'Fossils 142' ==model_name:
+        from inference_beit import inference_resnet_embedding_beit
+        model,n_classes = get_model(model_name)
+        result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
+        return result
+    elif 'Fossils new' ==model_name:
         from inference_beit import inference_resnet_embedding_beit
         model,n_classes = get_model(model_name)
         result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes)
@@ -114,30 +183,103 @@ def find_closest(input_image,model_name):
     #outputs = classes+paths
     return classes,paths
 
-def explain_image(input_image,model_name):
+def generate_diagram_closest(input_image,model_name,top_k):
+    embedding = get_embeddings(input_image,model_name)
+    diagram_path = get_diagram(embedding,top_k)
+    return diagram_path
+
+def explain_image(input_image,model_name,explain_method,nb_samples):
     model,n_classes= get_model(model_name)
-    if model_name=='Fossils 142':
+    if model_name=='Fossils 142' or 'Fossils new':
         size = 384
     else:
         size = 600
     #saliency, integrated, smoothgrad,
-    exp_list = explain(model,input_image,size = size, n_classes=n_classes)
+    classes,exp_list = explain(model,input_image,explain_method,nb_samples,size = size, n_classes=n_classes)
     #original = saliency + integrated + smoothgrad
     print('done')
-    sobol1,sobol2,sobol3,sobol4,sobol5 = exp_list[0],exp_list[1],exp_list[2],exp_list[3],exp_list[4]
-    rise1,rise2,rise3,rise4,rise5 = exp_list[5],exp_list[6],exp_list[7],exp_list[8],exp_list[9]
-    hsic1,hsic2,hsic3,hsic4,hsic5 = exp_list[10],exp_list[11],exp_list[12],exp_list[13],exp_list[14]
-    saliency1,saliency2,saliency3,saliency4,saliency5 = exp_list[15],exp_list[16],exp_list[17],exp_list[18],exp_list[19]
-    return sobol1,sobol2,sobol3,sobol4,sobol5,rise1,rise2,rise3,rise4,rise5,hsic1,hsic2,hsic3,hsic4,hsic5,saliency1,saliency2,saliency3,saliency4,saliency5
 
+    return classes,exp_list
+
+def setup_examples():
+    paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
+    samples = [path.as_posix() for path in paths if 'fossils' in str(path)][:19]
+    examples_fossils = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Fossils Examples from the dataset')
+    samples=[[path.as_posix()] for path in paths if 'leaves' in str(path) ][:19]
+    examples_leaves = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Leaves Examples from the dataset')
+    return examples_fossils,examples_leaves
+
+def preprocess_image(image, output_size=(300, 300)):
+    #shape (height, width, channels)
+    h, w = image.shape[:2]
+
+    #padding
+    if h > w:
+        padding = (h - w) // 2
+        image_padded = cv2.copyMakeBorder(image, 0, 0, padding, padding, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+    else:
+        padding = (w - h) // 2
+        image_padded = cv2.copyMakeBorder(image, padding, padding, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
+
+    # resize
+    image_resized = cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)
+
+    return image_resized
+
+def update_display(image):
+    processed_image = preprocess_image(image)
+    instruction = "Image ready. Please switch to the 'Specimen Workbench' tab to check out further analysis and outputs."
+    model_name = gr.Dropdown(
+        ["Mummified 170", "Rock 170","Fossils 142","Fossils new"],
+        multiselect=False,
+        value="Fossils new", # default option
+        label="Model",
+        interactive=True,
+        info="Choose the model you'd like to use"
+    )
+    explain_method = gr.Dropdown(
+        ["Sobol", "HSIC","Rise","Saliency"],
+        multiselect=False,
+        value="Rise", # default option
+        label="Explain method",
+        interactive=True,
+        info="Choose one method to explain the model"
+    )
+    sampling_size = gr.Slider(1, 5000, value=2000, label="Sampling Size in Rise",interactive=True,visible=True,
+                              info="Choose between 1 and 5000")
+
+    top_k = gr.Slider(10,200,value=50,label="Number of Closest Samples for Distribution Chart",interactive=True,info="Choose between 10 and 200")
+    class_predicted = gr.Label(label='Class Predicted',num_top_classes=10)
+    exp_gallery = gr.Gallery(label="Explanation Heatmaps for top 5 predicted classes", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
+    closest_gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
+    diagram= gr.Image(label = 'Bar Chart')
+    return processed_image,processed_image,instruction,model_name,explain_method,sampling_size,top_k,class_predicted,exp_gallery,closest_gallery,diagram
+def update_slider_visibility(explain_method):
+    bool = explain_method=="Rise"
+    return {sampling_size: gr.Slider(1, 5000, value=2000, label="Sampling Size in Rise", visible=bool, interactive=True)}
+
 #minimalist theme
 with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
 
     with gr.Tab(" Florrissant Fossils"):
-
+        gr.Markdown(HEADER)
         with gr.Row():
             with gr.Column():
-                input_image = gr.Image(label="Input")
+                gr.Markdown(USER_GUIDE)
+            with gr.Column(scale=2):
+                with gr.Column(scale=2):
+                    instruction_text = gr.Textbox(label="Instructions", value="Upload/Choose an image and click 'Process Image'.")
+                    input_image = gr.Image(label="Input",width="100%",container=True)
+                    process_button = gr.Button("Process Image")
+            with gr.Column(scale=1):
+                examples_fossils,examples_leaves = setup_examples()
+
+        gr.Markdown(CITATION)
+
+    with gr.Tab("Specimen Workbench"):
+        with gr.Row():
+            with gr.Column():
+                workbench_image = gr.Image(label="Workbench Image")
                 classify_image_button = gr.Button("Classify Image")
 
         # with gr.Column():
@@ -148,21 +290,101 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
 
             with gr.Column():
                 model_name = gr.Dropdown(
-                    ["Mummified 170", "Rock 170","Fossils 142"],
+                    ["Mummified 170", "Rock 170","Fossils 142","Fossils new"],
                     multiselect=False,
-                    value="Fossils 142",
+                    value="Fossils new", # default option
                     label="Model",
                     interactive=True,
+                    info="Choose the model you'd like to use"
+                )
+                explain_method = gr.Dropdown(
+                    ["Sobol", "HSIC","Rise","Saliency"],
+                    multiselect=False,
+                    value="Rise", # default option
+                    label="Explain method",
+                    interactive=True,
+                    info="Choose one method to explain the model"
                 )
+                # explain_method = gr.CheckboxGroup(["Sobol", "HSIC","Rise","Saliency"],
+                #                                   label="explain method",
+                #                                   value="Rise",
+                #                                   multiselect=False,
+                #                                   interactive=True,)
+                sampling_size = gr.Slider(1, 5000, value=2000, label="Sampling Size in Rise",interactive=True,visible=True,
+                                          info="Choose between 1 and 5000")
+
+                top_k = gr.Slider(10,200,value=50,label="Number of Closest Samples for Distribution Chart",interactive=True,info="Choose between 10 and 200")
+                explain_method.change(
+                    fn=update_slider_visibility,
+                    inputs=explain_method,
+                    outputs=sampling_size
+                )
+        with gr.Row():
+            with gr.Column(scale=1):
                 class_predicted = gr.Label(label='Class Predicted',num_top_classes=10)
+            with gr.Column(scale=4):
+                with gr.Accordion("Explanations "):
+                    gr.Markdown("Computing Explanations from the model")
+                    with gr.Column():
+                        with gr.Row():
+
+                            #original_input = gr.Image(label="Original Frame")
+                            #saliency = gr.Image(label="saliency")
+                            #gradcam = gr.Image(label='integraged gradients')
+                            #guided_gradcam = gr.Image(label='gradcam')
+                            #guided_backprop = gr.Image(label='guided backprop')
+                            # exp1 = gr.Image(label = 'Class_name1')
+                            # exp2= gr.Image(label = 'Class_name2')
+                            # exp3= gr.Image(label = 'Class_name3')
+                            # exp4= gr.Image(label = 'Class_name4')
+                            # exp5= gr.Image(label = 'Class_name5')
+
+                            exp_gallery = gr.Gallery(label="Explanation Heatmaps for top 5 predicted classes", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
+
+                    generate_explanations = gr.Button("Generate Explanations")
+
+                # with gr.Accordion('Closest Images'):
+                #     gr.Markdown("Finding the closest images in the dataset")
+                #     with gr.Row():
+                #         with gr.Column():
+                #             label_closest_image_0 = gr.Markdown('')
+                #             closest_image_0 = gr.Image(label='Closest Image',image_mode='contain',width=200, height=200)
+                #         with gr.Column():
+                #             label_closest_image_1 = gr.Markdown('')
+                #             closest_image_1 = gr.Image(label='Second Closest Image',image_mode='contain',width=200, height=200)
+                #         with gr.Column():
+                #             label_closest_image_2 = gr.Markdown('')
+                #             closest_image_2 = gr.Image(label='Third Closest Image',image_mode='contain',width=200, height=200)
+                #         with gr.Column():
+                #             label_closest_image_3 = gr.Markdown('')
+                #             closest_image_3 = gr.Image(label='Forth Closest Image',image_mode='contain', width=200, height=200)
+                #         with gr.Column():
+                #             label_closest_image_4 = gr.Markdown('')
+                #             closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
+                #     find_closest_btn = gr.Button("Find Closest Images")
+                with gr.Accordion('Closest Fossil Images'):
+                    gr.Markdown("Finding the closest images in the dataset")
+
+                    with gr.Row():
+                        closest_gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
+                        #.style(grid=[1, 5], height=200, width=200)
+
+                    find_closest_btn = gr.Button("Find Closest Images")
+
+                #segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
+                classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted)
+                # generate_exp.click(exp_image, inputs=[input_image,model_name,explain_method,sampling_size], outputs=[exp1,exp2,exp3,exp4,exp5]) #
+                with gr.Accordion('Closest Leaves Images'):
+                    gr.Markdown("5 closest leaves")
+                with gr.Accordion("Class Distribution of Closest Samples "):
+                    gr.Markdown("Visualize class distribution of top-k closest samples in our dataset")
+                    with gr.Column():
+                        with gr.Row():
+                            diagram= gr.Image(label = 'Bar Chart')
+
+                    generate_diagram = gr.Button("Generate Diagram")
 
-        with gr.Row():
-
-            paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
-            samples=[[path.as_posix()] for path in paths if 'fossils' in str(path) ][:19]
-            examples_fossils = gr.Examples(samples, inputs=input_image,examples_per_page=10,label='Fossils Examples from the dataset')
-            samples=[[path.as_posix()] for path in paths if 'leaves' in str(path) ][:19]
-            examples_leaves = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Leaves Examples from the dataset')
 
         # with gr.Accordion("Using Diffuser"):
         #     with gr.Column():
@@ -173,80 +395,20 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
         #     class_predicted2 = gr.Label(label='Class Predicted from diffuser')
         #     classify_button = gr.Button("Classify Image")
 
-
-        with gr.Accordion("Explanations "):
-            gr.Markdown("Computing Explanations from the model")
-            with gr.Column():
-                with gr.Row():
-
-                    #original_input = gr.Image(label="Original Frame")
-                    #saliency = gr.Image(label="saliency")
-                    #gradcam = gr.Image(label='integraged gradients')
-                    #guided_gradcam = gr.Image(label='gradcam')
-                    #guided_backprop = gr.Image(label='guided backprop')
-                    sobol1 = gr.Image(label = 'Sobol1')
-                    sobol2= gr.Image(label = 'Sobol2')
-                    sobol3= gr.Image(label = 'Sobol3')
-                    sobol4= gr.Image(label = 'Sobol4')
-                    sobol5= gr.Image(label = 'Sobol5')
-
-                with gr.Row():
-                    rise1 = gr.Image(label = 'Rise1')
-                    rise2 = gr.Image(label = 'Rise2')
-                    rise3 = gr.Image(label = 'Rise3')
-                    rise4 = gr.Image(label = 'Rise4')
-                    rise5 = gr.Image(label = 'Rise5')
-
-                with gr.Row():
-                    hsic1 = gr.Image(label = 'HSIC1')
-                    hsic2 = gr.Image(label = 'HSIC2')
-                    hsic3 = gr.Image(label = 'HSIC3')
-                    hsic4 = gr.Image(label = 'HSIC4')
-                    hsic5 = gr.Image(label = 'HSIC5')
-
-                with gr.Row():
-                    saliency1 = gr.Image(label = 'Saliency1')
-                    saliency2 = gr.Image(label = 'Saliency2')
-                    saliency3 = gr.Image(label = 'Saliency3')
-                    saliency4 = gr.Image(label = 'Saliency4')
-                    saliency5 = gr.Image(label = 'Saliency5')
-
-
-            generate_explanations = gr.Button("Generate Explanations")
-
-        # with gr.Accordion('Closest Images'):
-        #     gr.Markdown("Finding the closest images in the dataset")
-        #     with gr.Row():
-        #         with gr.Column():
-        #             label_closest_image_0 = gr.Markdown('')
-        #             closest_image_0 = gr.Image(label='Closest Image',image_mode='contain',width=200, height=200)
-        #         with gr.Column():
-        #             label_closest_image_1 = gr.Markdown('')
-        #             closest_image_1 = gr.Image(label='Second Closest Image',image_mode='contain',width=200, height=200)
-        #         with gr.Column():
-        #             label_closest_image_2 = gr.Markdown('')
-        #             closest_image_2 = gr.Image(label='Third Closest Image',image_mode='contain',width=200, height=200)
-        #         with gr.Column():
-        #             label_closest_image_3 = gr.Markdown('')
-        #             closest_image_3 = gr.Image(label='Forth Closest Image',image_mode='contain', width=200, height=200)
-        #         with gr.Column():
-        #             label_closest_image_4 = gr.Markdown('')
-        #             closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200)
-        #     find_closest_btn = gr.Button("Find Closest Images")
-        with gr.Accordion('Closest Images'):
-            gr.Markdown("Finding the closest images in the dataset")
-
-            with gr.Row():
-                gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None)
-                #.style(grid=[1, 5], height=200, width=200)
 
-
-
-
-
-
+        def update_exp_outputs(input_image,model_name,explain_method,nb_samples):
+            labels, images = explain_image(input_image,model_name,explain_method,nb_samples)
+            #labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
+            #labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
+            image_caption=[]
+            for i in range(5):
+                image_caption.append((images[i],"Predicted Class "+str(i)+": "+labels[i]))
+            return image_caption
+
+        generate_explanations.click(fn=update_exp_outputs, inputs=[input_image,model_name,explain_method,sampling_size], outputs=[exp_gallery])
+
         #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4])
-        def update_outputs(input_image,model_name):
+        def update_closest_outputs(input_image,model_name):
            labels, images = find_closest(input_image,model_name)
            #labels_html = "".join([f'<div style="display: inline-block; text-align: center; width: 18%;">{label}</div>' for label in labels])
            #labels_markdown = f"<div style='width: 100%; text-align: center;'>{labels_html}</div>"
@@ -255,8 +417,19 @@ with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
                image_caption.append((images[i],labels[i]))
            return image_caption
 
-        find_closest_btn.click(fn=update_outputs, inputs=[input_image,model_name], outputs=[gallery])
+        find_closest_btn.click(fn=update_closest_outputs, inputs=[input_image,model_name], outputs=[closest_gallery])
        #classify_segmented_button.click(classify_image, inputs=[segmented_image,model_name], outputs=class_predicted)
+
+        generate_diagram.click(generate_diagram_closest, inputs=[input_image,model_name,top_k], outputs=diagram)
+
+        process_button.click(
+            fn=update_display,
+            inputs=input_image,
+            outputs=[input_image,workbench_image,instruction_text,model_name,explain_method,sampling_size,top_k,class_predicted,exp_gallery,closest_gallery,diagram]
+        )
+
+
+
 
 demo.queue() # manage multiple incoming requests
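Note on the new explain_image: the committed guard `if model_name=='Fossils 142' or 'Fossils new':` is always true, because the bare non-empty string 'Fossils new' is truthy on its own regardless of model_name, so every model ends up with size = 384. A minimal corrected sketch of just that branch (a hypothetical fix, not part of this commit):

    # Test membership instead of or-ing a bare string literal,
    # so "Mummified 170" and "Rock 170" keep size = 600.
    if model_name in ('Fossils 142', 'Fossils new'):
        size = 384
    else:
        size = 600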
closest_sample.py
CHANGED
@@ -5,6 +5,8 @@ import pandas as pd
 import os
 from huggingface_hub import snapshot_download
 import requests
+import matplotlib.pyplot as plt
+from collections import Counter
 
 
 pca_fossils = pk.load(open('pca_fossils_170_finer.pkl','rb'))
@@ -23,7 +25,7 @@ embedding_fossils = np.load('dataset/embedding_fossils_170_finer.npy')
 
 fossils_pd= pd.read_csv('fossils_paths.csv')
 
-def pca_distance(pca,sample,embedding):
+def pca_distance(pca,sample,embedding,top_k):
    """
    Args:
        pca:fitted PCA model
@@ -35,7 +37,7 @@ def pca_distance(pca,sample,embedding):
    s = pca.transform(sample.reshape(1,-1))
    all = pca.transform(embedding[:,-1])
    distances = np.linalg.norm(all - s, axis=1)
-    return np.argsort(distances)[:5]
+    return np.argsort(distances)[:top_k]
 
 def return_paths(argsorted,files):
    paths= []
@@ -56,7 +58,7 @@ def get_images(embedding):
 
    #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
 
-    pca_d =pca_distance(pca_fossils,embedding,embedding_fossils)
+    pca_d =pca_distance(pca_fossils,embedding,embedding_fossils,top_k=5)
 
    fossils_paths = fossils_pd['file_name'].values
 
@@ -87,3 +89,54 @@ def get_images(embedding):
    #             '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
 
    return classes, local_paths
+
+def get_diagram(embedding,top_k):
+
+    #pca_embedding_fossils = pca_fossils.transform(embedding_fossils[:,-1])
+
+    pca_d =pca_distance(pca_fossils,embedding,embedding_fossils,top_k=top_k)
+
+    fossils_paths = fossils_pd['file_name'].values
+
+    paths = return_paths(pca_d,fossils_paths)
+    #print(paths)
+
+    folder_florissant = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/Florissant_Fossil_v2.0/'
+    folder_general = 'https://storage.googleapis.com/serrelab/prj_fossils/2024/General_Fossil_v2.0/'
+
+    classes = []
+    for i, path in enumerate(paths):
+        local_file_path = f'image_{i}.jpg'
+        if 'Florissant_Fossil/512/full/jpg/' in path:
+            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/', folder_florissant)
+        elif 'General_Fossil/512/full/jpg/' in path:
+            public_path = path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/General_Fossil/512/full/jpg/', folder_general)
+        else:
+            print("no match found")
+        print(public_path)
+        #download_public_image(public_path, local_file_path)
+        parts = [part for part in public_path.split('/') if part]
+        part = parts[-2]
+        classes.append(part)
+        #local_paths.append(local_file_path)
+    #paths= [path.replace('/gpfs/data/tserre/irodri15/Fossils/new_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/512/full/jpg/',
+    #                     '/media/data_cifs/projects/prj_fossils/data/processed_data/leavesdb-v1_1/images/Fossil/Florissant_Fossil/original/full/jpg/') for path in paths]
+    class_counts = Counter(classes)
+
+    sorted_class_counts = sorted(class_counts.items(), key=lambda item: item[1], reverse=True)
+    sorted_classes, sorted_frequencies = zip(*sorted_class_counts)
+    colors = plt.cm.viridis(np.linspace(0, 1, len(sorted_classes)))
+    fig, ax = plt.subplots()
+    ax.bar(sorted_classes, sorted_frequencies,color=colors)
+    ax.set_xlabel('Class Label')
+    ax.set_ylabel('Frequency')
+    ax.set_title('Distribution of '+str(top_k) +' Closest Sample Classes')
+    ax.set_xticklabels(class_counts.keys(), rotation=45, ha='right')
+
+    # Save the diagram to a file
+    diagram_path = 'class_distribution_chart.png'
+    plt.tight_layout() # Adjust layout to make room for rotated x-axis labels
+    plt.savefig(diagram_path)
+    plt.close() # Close the figure to free up memory
+
+    return diagram_path
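Note on get_diagram: the bars are drawn in frequency-sorted order (sorted_classes), but the tick labels come from class_counts.keys(), whose order is insertion order, so labels can land on the wrong bars; recent matplotlib versions also warn when set_xticklabels is used without a matching set_xticks. A minimal corrected sketch (a hypothetical fix, not part of this commit):

    # Label the ticks with the same sorted order used for the bars.
    ax.set_xticks(range(len(sorted_classes)))
    ax.set_xticklabels(sorted_classes, rotation=45, ha='right')

As wired in app.py, generate_diagram_closest embeds the input with get_embeddings, calls get_diagram(embedding, top_k), and hands the saved 'class_distribution_chart.png' path straight to the 'Bar Chart' gr.Image output.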
explanations.py
CHANGED
@@ -7,6 +7,7 @@ from xplique.attributions.global_sensitivity_analysis import LatinHypercube
 import numpy as np
 import matplotlib.pyplot as plt
 from inference_resnet import inference_resnet_finer, preprocess, _clever_crop
+from labels import lookup_140
 BATCH_SIZE = 1
 
 def show(img, p=False, **kwargs):
@@ -35,7 +36,7 @@ def show(img, p=False, **kwargs):
 
 
 
-def explain(model, input_image,size=600, n_classes=171) :
+def explain(model, input_image,explain_method,nb_samples,size=600, n_classes=171) :
    """
    Generate explanations for a given model and dataset.
    :param model: The model to explain.
@@ -45,31 +46,55 @@ def explain(model, input_image,size=600, n_classes=171) :
    :param batch_size: The batch size to use.
    :return: The explanations.
    """
-
+    print('using explain_method:',explain_method)
    # we only need the classification part of the model
    class_model = tf.keras.Model(model.input, model.output[1])
 
-    explainers = [
-
-
-
-
-            SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
-            Rise(class_model,nb_samples = 5000, batch_size = BATCH_SIZE,grid_size=15,
-                 preservation_probability=0.5),
-            HsicAttributionMethod(class_model,
+    explainers = []
+    if explain_method=="Sobol":
+        explainers.append(SobolAttributionMethod(class_model, grid_size=8, nb_design=32))
+    if explain_method=="HSIC":
+        explainers.append(HsicAttributionMethod(class_model,
                          grid_size=7, nb_design=1500,
-                          sampler = LatinHypercube(binary=True))
-
-
-
-
+                          sampler = LatinHypercube(binary=True)))
+    if explain_method=="Rise":
+        explainers.append(Rise(class_model,nb_samples = nb_samples, batch_size = BATCH_SIZE,grid_size=15,
+                               preservation_probability=0.5))
+    if explain_method=="Saliency":
+        explainers.append(Saliency(class_model))
+
+    # explainers = [
+    #     #Sobol, RISE, HSIC, Saliency
+    #     #IntegratedGradients(class_model, steps=50, batch_size=BATCH_SIZE),
+    #     #SmoothGrad(class_model, nb_samples=50, batch_size=BATCH_SIZE),
+    #     #GradCAM(class_model),
+    #     SobolAttributionMethod(class_model, grid_size=8, nb_design=32),
+    #     HsicAttributionMethod(class_model,
+    #                           grid_size=7, nb_design=1500,
+    #                           sampler = LatinHypercube(binary=True)),
+    #     Saliency(class_model),
+    #     Rise(class_model,nb_samples = 5000, batch_size = BATCH_SIZE,grid_size=15,
+    #          preservation_probability=0.5),
+    #     #
+    # ]
+
    cropped,repetitions = _clever_crop(input_image,(size,size))
-    size_repetitions = int(size//(repetitions.numpy()+1))
+    # size_repetitions = int(size//(repetitions.numpy()+1))
+    # print(size)
+    # print(type(input_image))
+    # print(input_image.shape)
+    # size_repetitions = int(size//(repetitions+1))
+    # print(type(repetitions))
+    # print(repetitions)
+    # print(size_repetitions)
+    # print(type(size_repetitions))
    X = preprocess(cropped,size=size)
    predictions = class_model.predict(np.array([X]))
    #Y = np.argmax(predictions)
    top_5_indices = np.argsort(predictions[0])[-5:][::-1]
+    classes = []
+    for index in top_5_indices:
+        classes.append(lookup_140[index])
    #print(top_5_indices)
    X = np.expand_dims(X, 0)
    explanations = []
@@ -81,8 +106,10 @@ def explain(model, input_image,size=600, n_classes=171) :
        phi = np.abs(explainer(X, Y))[0]
        if len(phi.shape) == 3:
            phi = np.mean(phi, -1)
-            show(X[0][:,size_repetitions:2*size_repetitions,:])
-            show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
+            show(X[0])
+            show(phi, p=1, alpha=0.4)
+            # show(X[0][:,size_repetitions:2*size_repetitions,:])
+            # show(phi[:,size_repetitions:2*size_repetitions], p=1, alpha=0.4)
            plt.savefig(f'phi_{e}{i}.png')
            explanations.append(f'phi_{e}{i}.png')
    # avg=[]
@@ -101,4 +128,4 @@ def explain(model, input_image,size=600, n_classes=171) :
    if len(explanations)==1:
        explanations = explanations[0]
    # return explanations,avg
-    return explanations
+    return classes,explanations
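With this change explain() runs only the attribution method selected in the UI, rather than all of them, and also returns readable names for the top-5 predicted classes via lookup_140. A sketch of a direct call, assuming a model loaded through get_model as in app.py:

    # Sketch only: 'model' as returned by get_model('Fossils new') in app.py.
    classes, heatmaps = explain(model, input_image,
                                explain_method="Rise",   # "Sobol", "HSIC", "Rise" or "Saliency"
                                nb_samples=2000,         # consumed by the Rise explainer only
                                size=384, n_classes=142)
    # classes: the five predicted class names; heatmaps: the saved phi_{e}{i}.png
    # paths, one per top-5 class for the single selected explainer.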
inference_resnet.py
CHANGED
@@ -7,7 +7,7 @@ else:
 
 from keras.applications import resnet
 import tensorflow.keras.layers as L
-import os
+import os
 
 from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
 import matplotlib.pyplot as plt