pamixsun commited on
Commit
5bfee0b
·
1 Parent(s): 8dcbbbd

Update glaucoma.py

Browse files
Files changed (1) hide show
  1. glaucoma.py +45 -7
glaucoma.py CHANGED
@@ -1,28 +1,40 @@
1
  import cv2
2
  import torch
3
 
4
- from transformers import AutoImageProcessor, Swinv2ForImageClassification
5
 
6
- from cam import ClassActivationMap
 
 
 
 
7
 
8
 
9
  class GlaucomaModel(object):
10
  def __init__(self,
11
  cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
 
12
  device=torch.device('cpu')):
13
  # where to load the model, gpu or cpu ?
14
  self.device = device
15
- # glaucoma classification model
16
  self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
17
  self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
 
 
 
18
  # class activation map
19
  self.cam = ClassActivationMap(self.cls_model, self.cls_extractor)
20
 
21
  # classification id to label
22
- self.id2label = self.cls_model.config.id2label
 
 
23
 
24
- # number of classes
25
- self.num_diseases = len(self.id2label)
 
 
26
 
27
  def glaucoma_pred(self, image):
28
  """
@@ -36,6 +48,29 @@ class GlaucomaModel(object):
36
  disease_idx = outputs.cpu()[0, :].detach().numpy().argmax()
37
 
38
  return disease_idx
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def process(self, image):
41
  """
@@ -46,6 +81,9 @@ class GlaucomaModel(object):
46
  disease_idx = self.glaucoma_pred(image)
47
  cam = self.cam.get_cam(image, disease_idx)
48
  cam = cv2.resize(cam, image_shape[::-1])
 
 
 
49
 
50
- return disease_idx, cam
51
 
 
1
  import cv2
2
  import torch
3
 
4
+ import numpy as np
5
 
6
+ from torch import nn
7
+ from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
8
+
9
+ from lib.cam import ClassActivationMap
10
+ from lib.utils import add_mask, simple_vcdr
11
 
12
 
13
  class GlaucomaModel(object):
14
  def __init__(self,
15
  cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
16
+ seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
17
  device=torch.device('cpu')):
18
  # where to load the model, gpu or cpu ?
19
  self.device = device
20
+ # classification model for glaucoma
21
  self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
22
  self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
23
+ # segmentation model for optic disc and cup
24
+ self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
25
+ self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
26
  # class activation map
27
  self.cam = ClassActivationMap(self.cls_model, self.cls_extractor)
28
 
29
  # classification id to label
30
+ self.cls_id2label = self.cls_model.config.id2label
31
+ # segmentation id to label
32
+ self.seg_id2label = self.seg_model.config.id2label
33
 
34
+ # number of classes for classification
35
+ self.num_diseases = len(self.cls_id2label)
36
+ # number of classes for segmentation
37
+ self.seg_classes = len(self.seg_id2label)
38
 
39
  def glaucoma_pred(self, image):
40
  """
 
48
  disease_idx = outputs.cpu()[0, :].detach().numpy().argmax()
49
 
50
  return disease_idx
51
+
52
+ def optic_disc_cup_pred(self, image):
53
+ """
54
+ Args:
55
+ image: image array in RGB order.
56
+ """
57
+ inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
58
+
59
+ with torch.no_grad():
60
+ inputs.to(self.device)
61
+ outputs = self.seg_model(**inputs)
62
+ logits = outputs.logits.cpu()
63
+
64
+ upsampled_logits = nn.functional.interpolate(
65
+ logits,
66
+ size=image.shape[:2],
67
+ mode="bilinear",
68
+ align_corners=False,
69
+ )
70
+
71
+ pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
72
+
73
+ return pred_disc_cup.numpy().astype(np.uint8)
74
 
75
  def process(self, image):
76
  """
 
81
  disease_idx = self.glaucoma_pred(image)
82
  cam = self.cam.get_cam(image, disease_idx)
83
  cam = cv2.resize(cam, image_shape[::-1])
84
+ disc_cup = self.optic_disc_cup_pred(image)
85
+ vcdr = simple_vcdr(disc_cup)
86
+ _, disc_cup_image = add_mask(image, disc_cup, [0, 1, 2], [[0, 0, 0], [0, 255, 0], [255, 0, 0]], 0.2)
87
 
88
+ return disease_idx, disc_cup_image, cam, vcdr
89