Update visualization
visualization (CHANGED, +2 -2)
@@ -33,14 +33,14 @@ model = vit_LRP(pretrained=True).cuda()
 model.eval()
 attribution_generator = LRP(model)

-def generate_visualization(original_image, class_index=None):
+def generate_visualization(original_image, class_index=None, use_threshold=False):
     transformer_attribution = attribution_generator.generate_LRP(original_image.unsqueeze(0).cuda(), method="transformer_attribution", index=class_index).detach()
     transformer_attribution = transformer_attribution.reshape(1, 1, 14, 14)
     transformer_attribution = torch.nn.functional.interpolate(transformer_attribution, scale_factor=16, mode='bilinear')
     transformer_attribution = transformer_attribution.reshape(224, 224).data.cpu().numpy()
     transformer_attribution = (transformer_attribution - transformer_attribution.min()) / (transformer_attribution.max() - transformer_attribution.min())

-    if
+    if use_threshold:
         transformer_attribution = transformer_attribution * 255
         transformer_attribution = transformer_attribution.astype(np.uint8)
         ret, transformer_attribution = cv2.threshold(transformer_attribution, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
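
For reference, a minimal usage sketch of the updated signature. It assumes the 224x224 ViT input implied by the 14x14 reshape and scale_factor=16 upsampling above; the preprocessing transform, normalization constants, and image path are illustrative assumptions, not part of this commit, what the function returns lies outside the hunk shown here, and a CUDA device is required since the attribution is computed on .cuda() tensors.

# Usage sketch -- preprocessing pipeline and inputs are assumed, not taken from this diff.
from PIL import Image
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),              # model input matches the 14x14 patch grid upscaled by 16
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],  # assumed normalization; match the Space's real transform
                         std=[0.5, 0.5, 0.5]),
])

image = preprocess(Image.open("example.jpg").convert("RGB"))  # hypothetical input image

# Default path: use_threshold=False keeps the continuous relevance map.
vis_soft = generate_visualization(image, class_index=None)

# New flag: enable the Otsu thresholding step for a binary attribution mask.
vis_binary = generate_visualization(image, class_index=None, use_threshold=True)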