pamixsun commited on
Commit
0499403
1 Parent(s): 723fedd

Upload 2 files

Browse files
Files changed (2) hide show
  1. cam.py +80 -0
  2. glaucoma.py +51 -0
cam.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import cv2
3
+ import torch
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+ from typing import List, Callable, Optional
8
+ from functools import partial
9
+
10
+ from pytorch_grad_cam import GradCAM
11
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
12
+ from pytorch_grad_cam.utils.image import show_cam_on_image
13
+
14
+
15
+ """ Model wrapper to return a tensor"""
16
+ class HuggingfaceToTensorModelWrapper(torch.nn.Module):
17
+ def __init__(self, model):
18
+ super(HuggingfaceToTensorModelWrapper, self).__init__()
19
+ self.model = model
20
+
21
+ def forward(self, x):
22
+ return self.model(x).logits
23
+
24
+
25
+ class ClassActivationMap(object):
26
+ def __init__(self, model, processor):
27
+ self.model = HuggingfaceToTensorModelWrapper(model)
28
+ target_layer = model.swinv2.layernorm
29
+ self.target_layer = [target_layer]
30
+ self.processor = processor
31
+
32
+ def swinT_reshape_transform_huggingface(self, tensor, width, height):
33
+ result = tensor.reshape(tensor.size(0),
34
+ height,
35
+ width,
36
+ tensor.size(2))
37
+ result = result.transpose(2, 3).transpose(1, 2)
38
+ return result
39
+
40
+ def run_grad_cam_on_image(self,
41
+ targets_for_gradcam: List[Callable],
42
+ reshape_transform: Optional[Callable],
43
+ input_tensor: torch.nn.Module,
44
+ input_image: Image,
45
+ method: Callable=GradCAM):
46
+ with method(model=self.model,
47
+ target_layers=self.target_layer,
48
+ reshape_transform=reshape_transform) as cam:
49
+
50
+ # Replicate the tensor for each of the categories we want to create Grad-CAM for:
51
+ # print(input_tensor.size())
52
+ repeated_tensor = input_tensor[None, :].repeat(len(targets_for_gradcam), 1, 1, 1)
53
+ # print(repeated_tensor.size())
54
+
55
+ batch_results = cam(input_tensor=repeated_tensor,
56
+ targets=targets_for_gradcam)
57
+ results = []
58
+ for grayscale_cam in batch_results:
59
+ visualization = show_cam_on_image(np.float32(input_image) / 255,
60
+ grayscale_cam,
61
+ use_rgb=True)
62
+ # Make it weight less in the notebook:
63
+ visualization = cv2.resize(visualization,
64
+ (visualization.shape[1] // 1, visualization.shape[0] // 1))
65
+ results.append(visualization)
66
+ return np.hstack(results)
67
+
68
+ def get_cam(self, image, category_id):
69
+ image = Image.fromarray(image).resize((self.processor.size['height'], self.processor.size['width']))
70
+ img_tensor = self.processor(images=image, return_tensors="pt")['pixel_values'].squeeze()
71
+ targets_for_gradcam = [ClassifierOutputTarget(category_id)]
72
+ reshape_transform = partial(self.swinT_reshape_transform_huggingface,
73
+ width=img_tensor.shape[2] // 32,
74
+ height=img_tensor.shape[1] // 32)
75
+ cam = self.run_grad_cam_on_image(input_tensor=img_tensor,
76
+ input_image=image,
77
+ targets_for_gradcam=targets_for_gradcam,
78
+ reshape_transform=reshape_transform)
79
+
80
+ return cam
glaucoma.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+
4
+ from transformers import AutoImageProcessor, Swinv2ForImageClassification
5
+
6
+ from lib.cam import ClassActivationMap
7
+
8
+
9
+ class GlaucomaModel(object):
10
+ def __init__(self,
11
+ cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
12
+ device=torch.device('cpu')):
13
+ # where to load the model, gpu or cpu ?
14
+ self.device = device
15
+ # classification model for nails disease
16
+ self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
17
+ self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
18
+ # class activation map
19
+ self.cam = ClassActivationMap(self.cls_model, self.cls_extractor)
20
+
21
+ # classification id to label
22
+ self.id2label = self.cls_model.config.id2label
23
+
24
+ # number of classes for nails disease
25
+ self.num_diseases = len(self.id2label)
26
+
27
+ def glaucoma_pred(self, image):
28
+ """
29
+ Args:
30
+ image: image array in RGB order.
31
+ """
32
+ inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
33
+ with torch.no_grad():
34
+ inputs.to(self.device)
35
+ outputs = self.cls_model(**inputs).logits
36
+ disease_idx = outputs.cpu()[0, :].detach().numpy().argmax()
37
+
38
+ return disease_idx
39
+
40
+ def process(self, image):
41
+ """
42
+ Args:
43
+ image: image array in RGB order.
44
+ """
45
+ image_shape = image.shape[:2]
46
+ disease_idx = self.glaucoma_pred(image)
47
+ cam = self.cam.get_cam(image, disease_idx)
48
+ cam = cv2.resize(cam, image_shape[::-1])
49
+
50
+ return disease_idx, cam
51
+