File size: 3,189 Bytes
0499403
 
 
5bfee0b
0499403
5bfee0b
 
 
 
 
0499403
 
 
 
 
5bfee0b
0499403
 
 
5bfee0b
0499403
 
5bfee0b
 
 
0499403
 
 
 
5bfee0b
 
 
0499403
5bfee0b
 
 
 
0499403
 
 
 
 
 
 
 
 
 
 
 
 
5bfee0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0499403
 
 
 
 
 
 
 
 
 
5bfee0b
 
 
0499403
5bfee0b
0499403
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import cv2
import torch

import numpy as np

from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation

from lib.cam import ClassActivationMap
from lib.utils import add_mask, simple_vcdr


class GlaucomaModel:
    """Glaucoma classification plus optic disc/cup segmentation in one object.

    Wraps a Swin V2 image classifier and a SegFormer semantic segmenter
    (each with its own image processor) together with a class activation
    map helper built on the classifier.
    """

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """
        Args:
            cls_model_path: identifier/path handed to ``from_pretrained`` for
                the classification processor and model.
            seg_model_path: identifier/path handed to ``from_pretrained`` for
                the segmentation processor and model.
            device: torch device both models are moved to.
        """
        # Device used for both models and for the inputs fed to them.
        self.device = device

        # Glaucoma classifier: processor first, then the network in eval mode.
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        classifier = Swinv2ForImageClassification.from_pretrained(cls_model_path)
        self.cls_model = classifier.to(device).eval()

        # Optic disc / cup segmenter, loaded the same way.
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        segmenter = SegformerForSemanticSegmentation.from_pretrained(seg_model_path)
        self.seg_model = segmenter.to(device).eval()

        # Class activation map helper, built from the classifier and its processor.
        self.cam = ClassActivationMap(self.cls_model, self.cls_extractor)

        # Index -> label tables taken from each model's config, and their sizes.
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label
        self.num_diseases = len(self.cls_id2label)
        self.seg_classes = len(self.seg_id2label)

    def glaucoma_pred(self, image):
        """Classify one image and return the index of the top-scoring class.

        Args:
            image: image array in RGB order.

        Returns:
            Index of the largest logit for the first item in the batch.
        """
        # The processor receives a copy of the caller's array.
        batch = self.cls_extractor(images=image.copy(), return_tensors="pt")

        with torch.no_grad():
            # The return value of .to() is not used, so this relies on the
            # batch being moved in place.
            # NOTE(review): confirm for the processor's output type.
            batch.to(self.device)
            logits = self.cls_model(**batch).logits

        first_item_scores = logits.cpu()[0, :].detach().numpy()
        return first_item_scores.argmax()

    def optic_disc_cup_pred(self, image):
        """Segment one image and return a per-pixel label map.

        Args:
            image: image array in RGB order.

        Returns:
            ``np.uint8`` array whose shape equals ``image.shape[:2]``; each
            entry is the arg-max class index at that pixel.
        """
        batch = self.seg_extractor(images=image.copy(), return_tensors="pt")

        with torch.no_grad():
            # Same in-place move as in glaucoma_pred.
            batch.to(self.device)
            seg_output = self.seg_model(**batch)

        # Resample the logits to the input's first two dimensions before
        # taking the per-pixel arg-max.
        target_size = image.shape[:2]
        full_res_logits = nn.functional.interpolate(
            seg_output.logits.cpu(),
            size=target_size,
            mode="bilinear",
            align_corners=False,
        )

        label_map = full_res_logits.argmax(dim=1)[0]
        return label_map.numpy().astype(np.uint8)

    def process(self, image):
        """Run classification, CAM, segmentation and the vCDR computation.

        Args:
            image: image array in RGB order.

        Returns:
            Tuple ``(disease_idx, disc_cup_image, cam, vcdr)``:
            the predicted class index, the second value returned by
            ``add_mask``, the class activation map resized to the image's
            first two dimensions, and the result of ``simple_vcdr`` on the
            segmentation map.
        """
        dims = image.shape[:2]

        disease_idx = self.glaucoma_pred(image)

        # cv2.resize takes its size in the reverse order of the leading
        # array dimensions, hence the flip.
        raw_cam = self.cam.get_cam(image, disease_idx)
        cam = cv2.resize(raw_cam, dims[::-1])

        disc_cup = self.optic_disc_cup_pred(image)
        vcdr = simple_vcdr(disc_cup)

        # NOTE(review): presumably label ids, one colour per label, and a
        # blend weight -- confirm against lib.utils.add_mask.
        mask_values = [0, 1, 2]
        mask_colors = [[0, 0, 0], [0, 255, 0], [255, 0, 0]]
        blend = 0.2
        _, disc_cup_image = add_mask(image, disc_cup, mask_values, mask_colors, blend)

        return disease_idx, disc_cup_image, cam, vcdr