andy-wyx commited on
Commit
99ddcfc
·
1 Parent(s): 6dd79e3

debugging: xai output distortion

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. explanations.py +2 -6
app.py CHANGED
@@ -141,7 +141,7 @@ def classify_image(input_image, model_name):
141
  from inference_resnet import inference_resnet_finer
142
  model,n_classes= get_model(model_name)
143
  result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
144
- return result
145
  elif 'Mummified 170' ==model_name:
146
  from inference_resnet import inference_resnet_finer
147
  model, n_classes= get_model(model_name)
@@ -169,7 +169,7 @@ def get_embeddings(input_image,model_name):
169
  from inference_resnet import inference_resnet_embedding
170
  model,n_classes= get_model(model_name)
171
  result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
172
- return result
173
  elif 'Mummified 170' ==model_name:
174
  from inference_resnet import inference_resnet_embedding
175
  model, n_classes= get_model(model_name)
 
141
  from inference_resnet import inference_resnet_finer
142
  model,n_classes= get_model(model_name)
143
  result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes)
144
+ return result
145
  elif 'Mummified 170' ==model_name:
146
  from inference_resnet import inference_resnet_finer
147
  model, n_classes= get_model(model_name)
 
169
  from inference_resnet import inference_resnet_embedding
170
  model,n_classes= get_model(model_name)
171
  result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes)
172
+ return result
173
  elif 'Mummified 170' ==model_name:
174
  from inference_resnet import inference_resnet_embedding
175
  model, n_classes= get_model(model_name)
explanations.py CHANGED
@@ -29,13 +29,8 @@ def preprocess_image(image, output_size=(300, 300)):
29
  return image_resized
30
 
31
  def show(img, output_size,p=False, **kwargs):
32
- img = np.array(img, dtype=np.float32)
33
- h, w = img.shape[:2]
34
- print(h,w)
35
 
36
- img = preprocess_image(img, output_size=(output_size,output_size))
37
- h, w = img.shape[:2]
38
- print(h,w)
39
 
40
  # check if channel first
41
  if img.shape[0] == 1:
@@ -53,6 +48,7 @@ def show(img, output_size,p=False, **kwargs):
53
  # check if clip percentile
54
  if p is not False:
55
  img = np.clip(img, np.percentile(img, p), np.percentile(img, 100-p))
 
56
  plt.imshow(img, **kwargs)
57
  plt.axis('off')
58
 
 
29
  return image_resized
30
 
31
  def show(img, output_size,p=False, **kwargs):
 
 
 
32
 
33
+ #img = preprocess_image(img, output_size=(output_size,output_size))
 
 
34
 
35
  # check if channel first
36
  if img.shape[0] == 1:
 
48
  # check if clip percentile
49
  if p is not False:
50
  img = np.clip(img, np.percentile(img, p), np.percentile(img, 100-p))
51
+ img = preprocess_image(img, output_size=(output_size,output_size))
52
  plt.imshow(img, **kwargs)
53
  plt.axis('off')
54