File size: 3,189 Bytes
0499403 5bfee0b 0499403 5bfee0b 0499403 5bfee0b 0499403 5bfee0b 0499403 5bfee0b 0499403 5bfee0b 0499403 5bfee0b 0499403 5bfee0b 0499403 5bfee0b 0499403 5bfee0b 0499403 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import cv2
import torch
import numpy as np
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
from lib.cam import ClassActivationMap
from lib.utils import add_mask, simple_vcdr
class GlaucomaModel:
    """Glaucoma screening pipeline built from two pretrained models.

    Combines a Swin-V2 image classifier (used for the glaucoma prediction and,
    via ``ClassActivationMap``, a class activation map) with a SegFormer
    semantic-segmentation model for the optic disc and cup.
    """

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load both models and their image processors.

        Args:
            cls_model_path: identifier or path passed to ``from_pretrained``
                for the classification processor and model.
            seg_model_path: identifier or path passed to ``from_pretrained``
                for the segmentation processor and model.
            device: device both models are moved to; inputs are moved to the
                same device at prediction time.
        """
        # where the models (and their inputs) live, gpu or cpu
        self.device = device
        # classification model for glaucoma, in eval mode for inference
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(
            cls_model_path).to(device).eval()
        # segmentation model for optic disc and cup, in eval mode for inference
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(
            seg_model_path).to(device).eval()
        # class activation map computed over the classifier
        self.cam = ClassActivationMap(self.cls_model, self.cls_extractor)
        # id -> label maps, read from each checkpoint's config
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label
        # number of classification / segmentation classes
        self.num_diseases = len(self.cls_id2label)
        self.seg_classes = len(self.seg_id2label)

    def glaucoma_pred(self, image):
        """Classify one image and return the index of the top-scoring class.

        Args:
            image: image array in RGB order.

        Returns:
            NumPy integer: argmax of the classifier logits for the first (and
            only) image in the batch; look it up in ``self.cls_id2label``.
        """
        # defensive copy so the caller's array is never handed to the processor
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        # Rebind instead of relying on .to() mutating in place, so the inputs
        # are guaranteed to be on the model's device.
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        # gradients are disabled above, so no .detach() is needed
        return logits[0].cpu().numpy().argmax()

    def optic_disc_cup_pred(self, image):
        """Segment the optic disc and cup in one image.

        Args:
            image: image array in RGB order; its first two dimensions are used
                as the output height and width.

        Returns:
            ``np.uint8`` array of shape ``image.shape[:2]`` holding the
            per-pixel argmax class index (labels in ``self.seg_id2label``).
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        # Rebind so the inputs are guaranteed to be on the model's device.
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()
        # Resize the logits (rather than hard labels) to the input's
        # height/width, then take the per-pixel argmax over classes.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],
            mode="bilinear",
            align_corners=False,
        )
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        return pred_disc_cup.numpy().astype(np.uint8)

    def process(self, image):
        """Run classification, CAM and disc/cup segmentation on one image.

        Args:
            image: image array in RGB order.

        Returns:
            Tuple ``(disease_idx, disc_cup_image, cam, vcdr)``:
              - disease_idx: class index from ``glaucoma_pred``.
              - disc_cup_image: second value returned by ``add_mask`` for the
                segmentation mask (presumably the image with the mask
                overlaid — NOTE(review): confirm against lib.utils).
              - cam: ``self.cam.get_cam`` output for ``disease_idx``, resized
                to the image's height and width.
              - vcdr: ``simple_vcdr`` of the mask (presumably the vertical
                cup-to-disc ratio — confirm against lib.utils).
        """
        height, width = image.shape[:2]
        disease_idx = self.glaucoma_pred(image)
        cam = self.cam.get_cam(image, disease_idx)
        # cv2.resize expects dsize as (width, height)
        cam = cv2.resize(cam, (width, height))
        disc_cup = self.optic_disc_cup_pred(image)
        vcdr = simple_vcdr(disc_cup)
        # Presumably maps labels 0/1/2 to RGB colours black/green/red with 0.2
        # as a blend weight — NOTE(review): confirm add_mask's parameters.
        _, disc_cup_image = add_mask(
            image, disc_cup, [0, 1, 2],
            [[0, 0, 0], [0, 255, 0], [255, 0, 0]], 0.2)
        return disease_idx, disc_cup_image, cam, vcdr
|