ferferefer commited on
Commit
f545fa4
1 Parent(s): 2c57916

Delete glaucoma.py

Browse files
Files changed (1) hide show
  1. glaucoma.py +0 -92
glaucoma.py DELETED
@@ -1,92 +0,0 @@
1
- import cv2
2
- import torch
3
-
4
- import numpy as np
5
-
6
- from torch import nn
7
- from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
8
-
9
- from cam import ClassActivationMap
10
- from utils import add_mask, simple_vcdr
11
-
12
-
13
- class GlaucomaModel(object):
14
- def __init__(self,
15
- cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
16
- seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
17
- device=torch.device('cpu')):
18
- # where to load the model, gpu or cpu ?
19
- self.device = device
20
- # classification model for glaucoma
21
- self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
22
- self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()
23
- # segmentation model for optic disc and cup
24
- self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
25
- self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()
26
- # class activation map
27
- self.cam = ClassActivationMap(self.cls_model, self.cls_extractor)
28
-
29
- # classification id to label
30
- self.cls_id2label = self.cls_model.config.id2label
31
- # segmentation id to label
32
- self.seg_id2label = self.seg_model.config.id2label
33
-
34
- # number of classes for classification
35
- self.num_diseases = len(self.cls_id2label)
36
- # number of classes for segmentation
37
- self.seg_classes = len(self.seg_id2label)
38
-
39
- def glaucoma_pred(self, image):
40
- """
41
- Args:
42
- image: image array in RGB order.
43
- """
44
- inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
45
- with torch.no_grad():
46
- inputs.to(self.device)
47
- outputs = self.cls_model(**inputs).logits
48
- disease_idx = outputs.cpu()[0, :].detach().numpy().argmax()
49
-
50
- return disease_idx
51
-
52
- def optic_disc_cup_pred(self, image):
53
- """
54
- Args:
55
- image: image array in RGB order.
56
- """
57
- inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
58
-
59
- with torch.no_grad():
60
- inputs.to(self.device)
61
- outputs = self.seg_model(**inputs)
62
- logits = outputs.logits.cpu()
63
-
64
- upsampled_logits = nn.functional.interpolate(
65
- logits,
66
- size=image.shape[:2],
67
- mode="bilinear",
68
- align_corners=False,
69
- )
70
-
71
- pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
72
-
73
- return pred_disc_cup.numpy().astype(np.uint8)
74
-
75
- def process(self, image):
76
- """
77
- Args:
78
- image: image array in RGB order.
79
- """
80
- image_shape = image.shape[:2]
81
- disease_idx = self.glaucoma_pred(image)
82
- cam = self.cam.get_cam(image, disease_idx)
83
- cam = cv2.resize(cam, image_shape[::-1])
84
- disc_cup = self.optic_disc_cup_pred(image)
85
- try:
86
- vcdr = simple_vcdr(disc_cup)
87
- except:
88
- vcdr = np.nan
89
- _, disc_cup_image = add_mask(image, disc_cup, [1, 2], [[0, 255, 0], [255, 0, 0]], 0.2)
90
-
91
- return disease_idx, disc_cup_image, cam, vcdr
92
-