# glaucoma_screening / glaucoma.py
# Scraped from a repository file view: author pamixsun, latest commit 0995b85
# ("Update glaucoma.py"), file size 3.18 kB, page marked "No virus".
import cv2
import torch
import numpy as np
from torch import nn
from transformers import AutoImageProcessor, Swinv2ForImageClassification, SegformerForSemanticSegmentation
from cam import ClassActivationMap
from utils import add_mask, simple_vcdr
class GlaucomaModel(object):
    """Glaucoma screening pipeline for a single RGB fundus image.

    Combines three components:

    * a Swin-V2 image classifier that predicts a glaucoma class index,
    * a class activation map (``ClassActivationMap``) over that classifier,
    * a SegFormer semantic-segmentation model for the optic disc and cup, whose
      label map is fed to ``simple_vcdr`` and ``add_mask``.
    """

    def __init__(self,
                 cls_model_path="pamixsun/swinv2_tiny_for_glaucoma_classification",
                 seg_model_path='pamixsun/segformer_for_optic_disc_cup_segmentation',
                 device=torch.device('cpu')):
        """Load both models and their image processors.

        Args:
            cls_model_path: model id or path passed to ``from_pretrained`` for
                the Swin-V2 glaucoma classifier.
            seg_model_path: model id or path passed to ``from_pretrained`` for
                the SegFormer optic disc/cup segmenter.
            device: torch device the models are moved to; model inputs are
                moved to the same device at inference time.
        """
        # Where to run the models: GPU or CPU.
        self.device = device

        # Glaucoma classification model (set to eval mode) and its preprocessor.
        self.cls_extractor = AutoImageProcessor.from_pretrained(cls_model_path)
        self.cls_model = Swinv2ForImageClassification.from_pretrained(cls_model_path).to(device).eval()

        # Optic disc / cup segmentation model (set to eval mode) and its preprocessor.
        self.seg_extractor = AutoImageProcessor.from_pretrained(seg_model_path)
        self.seg_model = SegformerForSemanticSegmentation.from_pretrained(seg_model_path).to(device).eval()

        # Class activation map built on top of the classifier.
        self.cam = ClassActivationMap(self.cls_model, self.cls_extractor)

        # Label maps from the model configs: class id -> label name.
        self.cls_id2label = self.cls_model.config.id2label
        self.seg_id2label = self.seg_model.config.id2label

        # Number of classification classes and segmentation classes.
        self.num_diseases = len(self.cls_id2label)
        self.seg_classes = len(self.seg_id2label)

    def glaucoma_pred(self, image):
        """Predict the glaucoma class for a single image.

        Args:
            image: image array in RGB order.

        Returns:
            int: index of the highest-scoring class for the first (only) image
            in the batch; look up its name with ``self.cls_id2label``.
        """
        # Pass a copy, presumably so preprocessing cannot touch the caller's array.
        inputs = self.cls_extractor(images=image.copy(), return_tensors="pt")
        # Rebind the result instead of relying on `.to` mutating the batch in place.
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.cls_model(**inputs).logits
        # Plain Python int (not np.int64); same value, same dict/index behaviour.
        return int(logits.cpu()[0, :].numpy().argmax())

    def optic_disc_cup_pred(self, image):
        """Segment the optic disc and cup in a single image.

        Args:
            image: image array in RGB order; its first two dimensions are
                used as the output (height, width).

        Returns:
            np.ndarray: ``uint8`` label map of shape ``image.shape[:2]`` holding
            the per-pixel argmax segmentation class id (see ``self.seg_id2label``).
        """
        inputs = self.seg_extractor(images=image.copy(), return_tensors="pt")
        # Rebind the result instead of relying on `.to` mutating the batch in place.
        inputs = inputs.to(self.device)
        with torch.no_grad():
            logits = self.seg_model(**inputs).logits.cpu()
        # Resize the logits to the input image's (height, width) before taking
        # the argmax, so the label map lines up pixel-for-pixel with `image`.
        upsampled_logits = nn.functional.interpolate(
            logits,
            size=image.shape[:2],
            mode="bilinear",
            align_corners=False,
        )
        pred_disc_cup = upsampled_logits.argmax(dim=1)[0]
        return pred_disc_cup.numpy().astype(np.uint8)

    def process(self, image):
        """Run classification, CAM and disc/cup segmentation on one image.

        Args:
            image: image array in RGB order.

        Returns:
            tuple: ``(disease_idx, disc_cup_image, cam, vcdr)``:

            * ``disease_idx`` -- class index from :meth:`glaucoma_pred`;
            * ``disc_cup_image`` -- second value returned by ``add_mask`` for
              the segmentation (presumably the image with the mask overlaid);
            * ``cam`` -- class activation map for ``disease_idx``, resized to
              the image's (height, width);
            * ``vcdr`` -- result of ``simple_vcdr`` on the label map
              (presumably the vertical cup-to-disc ratio).
        """
        height, width = image.shape[:2]

        disease_idx = self.glaucoma_pred(image)
        cam = self.cam.get_cam(image, disease_idx)
        # cv2.resize expects dsize as (width, height).
        cam = cv2.resize(cam, (width, height))

        disc_cup = self.optic_disc_cup_pred(image)
        vcdr = simple_vcdr(disc_cup)
        # Labels 0/1/2 are drawn in black/green/red; 0.2 is add_mask's last
        # argument (presumably the blend weight).
        # NOTE(review): presumably 0/1/2 = background/disc/cup -- confirm
        # against self.seg_id2label.
        _, disc_cup_image = add_mask(image, disc_cup, [0, 1, 2],
                                     [[0, 0, 0], [0, 255, 0], [255, 0, 0]], 0.2)
        return disease_idx, disc_cup_image, cam, vcdr