|
|
|
import streamlit as st |
|
import pandas as pd |
|
import plotly.express as px |
|
import plotly.graph_objects as go |
|
import numpy as np |
|
|
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torchvision.models.segmentation import deeplabv3_resnet50 |
|
from torchvision.transforms.functional import to_tensor |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
|
|
import lib.utils as libUtils |
|
|
|
import sys |
|
import os, random, io |
|
|
|
# Page name printed by run() when the page loads.
description = "Diagnosis"
# When True, run() prints its extra TRACE1 lines to stdout.
m_kblnTraceOn = True

# Number of segmentation classes (background, whole tumor, viable tumor —
# matching sem_classes in predGradCam_tile).
NUM_CLASSES = 3

# Trained weights file, loaded by readyModelWithWeights() via model_path.
BESTMODEL_PATH = "model.pth"
MODEL_FULLPATH = f"bin/models/{BESTMODEL_PATH}"
model_path = MODEL_FULLPATH

DEFAULT_DEVICE_TYPE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Backbone key understood by prepare_model(): 'mbv3', 'r50' or 'r101'.
DEFAULT_BACKBONE_MODEL = 'r50'
backbone_model_name = DEFAULT_BACKBONE_MODEL
|
|
|
|
|
|
|
def image_toBytesIO(image: "Image.Image") -> io.BytesIO:
    """Encode a PIL image into an in-memory stream.

    The image is written in its own ``format`` (e.g. "PNG" for a file opened
    from disk). Images created in memory have ``format`` set to None, so those
    are encoded as PNG instead of making ``save`` fail.

    Args:
        image: PIL image (anything with ``format`` and ``save(fp, format=...)``).

    Returns:
        io.BytesIO holding the encoded image, rewound to offset 0 so it can be
        read like a freshly opened file.
    """
    imgByteArr = io.BytesIO()
    image.save(imgByteArr, format=image.format or "PNG")
    # save() leaves the cursor at EOF; rewind so read()/Image.open() see the data.
    imgByteArr.seek(0)
    return imgByteArr
|
|
|
|
|
def image_toByteArray(image: Image) -> bytes:
    """Return the encoded image as raw bytes (same encoding as image_toBytesIO)."""
    return image_toBytesIO(image).getvalue()
|
|
|
|
|
def _loadRandomSample(strPth_sample):
    """Pick a random file from ``strPth_sample`` and wrap it as an in-memory upload.

    Returns an io.BytesIO carrying ``name`` and ``type`` attributes so it can
    stand in for a Streamlit UploadedFile, or None when the directory contains
    no regular files.
    """
    import mimetypes

    # Only regular files are candidates; sub-directories cannot be opened as images.
    lstFiles = [strFil for strFil in os.listdir(strPth_sample)
                if os.path.isfile(os.path.join(strPth_sample, strFil))]
    if not lstFiles:
        print("WARN (litDiagnosis.run): no sample files found in ", strPth_sample)
        return None

    strFil_sample = random.choice(lstFiles)
    strFullPth_sample = os.path.join(strPth_sample, strFil_sample)
    print("INFO (litDiagnosis.run): sample file selected ... ", strFullPth_sample)

    # The context manager closes the file handle once the image has been re-encoded.
    with Image.open(strFullPth_sample) as imgSample:
        imgUploaded = image_toBytesIO(imgSample)

    imgUploaded.name = strFil_sample
    # Uploaded files report a MIME type (e.g. "image/png"); mirror that, and
    # fall back to the bare extension when the type cannot be guessed.
    imgUploaded.type = (mimetypes.guess_type(strFil_sample)[0]
                        or os.path.splitext(strFil_sample)[1])
    return imgUploaded


def run():
    """Render the "Single Tile Diagnosis" Streamlit page.

    The tile comes either from the "Random Sample" button (a random file from
    ``libUtils.pth_dtaTileSamples``) or from the file uploader; an upload takes
    precedence. The chosen tile is displayed via ``showDiagnosis_horiz``.
    Failures during diagnosis are printed to stdout and shown on the page.
    """
    print("\nINFO (litDiagnosis.run) loading ", description, " page ...")

    if (m_kblnTraceOn): print("TRACE1 (litDiagnosis.run): Initialize Page Settings ...")

    st.markdown("#### Single Tile Diagnosis")

    imgUploaded = None
    if st.button("Random Sample"):
        imgUploaded = _loadRandomSample(libUtils.pth_dtaTileSamples)

    imgDropped = st.file_uploader("Upload a single Tile",
                                  type=["png", "jpg", "tif", "tiff", "img"],
                                  accept_multiple_files=False)
    # An explicit upload overrides the random sample.
    if imgDropped is not None:
        imgUploaded = imgDropped

    if imgUploaded is None:
        if (m_kblnTraceOn): print("ERROR (litDiagnosis.run): imgUploaded is None ...")
        return

    try:
        if (m_kblnTraceOn): print("TRACE1 (litDiagnosis.run): Print uploaded file details ...")
        st.write("FileName:", "   ", imgUploaded.name)
        st.write("FileType:", "   ", imgUploaded.type)

        showDiagnosis_horiz(imgUploaded)

    except TypeError as e:
        print("ERROR (litDiagnosis.run_typeError1): ", e)
        st.error(f"Diagnosis failed: {e}")

    except Exception as e:
        # Catch Exception rather than using a bare except: Streamlit's
        # st.stop()/st.rerun() signal via BaseException subclasses, which must
        # propagate to the script runner instead of being swallowed here.
        print("ERROR (litDiagnosis.run_genError1): ", sys.exc_info())
        st.error(f"Diagnosis failed: {e}")
|
|
|
|
|
def showImg_wsi(img):
    """Placeholder for whole-slide-image display; currently just prints an empty line."""
    print()
|
|
|
|
|
def readyModel_getPreds(imgDropped):
    """Persist an uploaded tile, build the XAI model and compute all predictions.

    Args:
        imgDropped: uploaded tile (see save_tilRaw for the expected interface).

    Returns:
        Tuple ``(raw tile path, model output, background CAMs,
        whole-tumor CAMs, viable-tumor CAMs)``; the CAM entries are the
        lists produced by predGradCam_tile.
    """
    print("TRACE: readyModel_getPreds ...")

    print("INFO: save raw tile ...")
    strPth_tilRaw = save_tilRaw(imgDropped)

    # The model is built up in three stages, each wrapping/returning the previous one.
    print("INFO: ready base model ...")
    mdl = readyBaseModel()
    print("INFO: ready model with weights ...")
    mdl = readyModelWithWeights(mdl)
    print("INFO: ready model with xai ...")
    mdl = readyModelWithXAI(mdl)

    print("INFO: get xai weighted pred ...")
    output_pred, tns_batch = predXai_tile(mdl, strPth_tilRaw)

    print("INFO: get GRADCAM preds ...")
    cam_bg, cam_wt, cam_vt = predGradCam_tile(output_pred, mdl, tns_batch)

    print("TRACE: return readyModel_getPreds ...")
    return strPth_tilRaw, output_pred, cam_bg, cam_wt, cam_vt
|
|
|
|
|
def showDiagnosis_horiz(imgDropped):
    """Run the model on one tile and show the results in five side-by-side columns.

    The first row of columns holds captions; the second holds the raw tile,
    the predicted mask and one GradCAM overlay each for background, whole
    tumor and viable tumor.
    """
    st.write("#")

    print("TRACE2: ready model ...")
    strPth_tilRaw, xai_pred, cam_img_bg, cam_img_wt, cam_img_vt = readyModel_getPreds(imgDropped)

    print("TRACE2: display raw preds, headers ...")
    tplHeaders = ("Raw Tile", "Prediction", "GradCAM: Background",
                  "GradCAM: Whole Tumor", "GradCAM: Viable Tumor")
    for colHeader, strHeader in zip(st.columns(5), tplHeaders):
        colHeader.write(strHeader)

    colRaw, colPred, colGradBack, colGradWhole, colGradViable = st.columns(5)
    showCol_rawTil(colRaw, strPth_tilRaw)
    showCol_predTil(colPred, xai_pred[0], strPth_tilRaw)
    # Only the first image of each CAM list is shown (single-tile batch).
    for strContext, colCam, lstCams in (("imgGradCam_bg", colGradBack, cam_img_bg),
                                        ("imgGradCam_wt", colGradWhole, cam_img_wt),
                                        ("imgGradCam_vt", colGradViable, cam_img_vt)):
        showCol_gradCamImg(strContext, colCam, lstCams[0])
|
|
|
|
|
def showCol_rawTil(colRaw, strPth_tilRaw):
    """Display the saved raw tile (given by its file path) in a Streamlit column."""
    print("TRACE3: showCol_rawTil ...")
    dctImageOpts = {"width": 400, "use_column_width": True}
    colRaw.image(strPth_tilRaw, **dctImageOpts)
|
|
|
|
|
|
|
|
|
def showCol_predTil(colPred, xai_pred, strPth_tilRaw):
    """Render the predicted class mask of one tile into a Streamlit column.

    The per-pixel argmax over ``xai_pred`` is colour-mapped, saved under
    ``data/tiles/pred/`` and then displayed.

    Args:
        colPred: Streamlit column (anything with an ``image`` method).
        xai_pred: class-score tensor for one tile; the class axis is assumed
            to be dim 0 (callers pass ``output[0]``) — TODO confirm.
        strPth_tilRaw: path of the saved raw tile; its base name is reused for
            the prediction image.
    """
    kstrPth_tilePred = "data/tiles/pred/"
    # Always write a PNG: raw tiles may use an extension matplotlib cannot
    # encode (".img" is accepted by the uploader), and a lossless format keeps
    # the class colours exact (JPEG would smear them).
    strFilName = os.path.splitext(os.path.basename(strPth_tilRaw))[0] + ".png"
    strFil_tilePred = os.path.join(kstrPth_tilePred, strFilName)

    ensureDirExists(kstrPth_tilePred)

    print("TRACE3: showCol_predTil2 ... ", strFil_tilePred)
    argmax_mask = torch.argmax(xai_pred, dim=0)
    preds = argmax_mask.cpu().squeeze().numpy()

    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap(name, lut)
    # is the supported way to get a colormap resampled to one colour per class.
    cmap = plt.get_cmap('tab10', NUM_CLASSES)
    print("TRACE3: typeOf(preds) ...", type(preds))

    print("TRACE3: save pred image ...")
    # Fixed vmin/vmax so class k always gets the same colour, whatever classes appear.
    plt.imsave(strFil_tilePred, preds, cmap=cmap, vmin=0, vmax=NUM_CLASSES - 1)

    print("TRACE3: load image ...", strFil_tilePred)
    colPred.image(strFil_tilePred, width=400, use_column_width=True)
|
|
|
|
|
def showCol_gradCamImg(strImgContext, colGradCam, cam_img):
    """Display one GradCAM overlay array in a Streamlit column.

    ``strImgContext`` is only used to label the trace output.
    """
    print("TRACE3: showCol_gradImg ... ", strImgContext)
    colGradCam.image(Image.fromarray(cam_img), width=400, use_column_width=True)
|
|
|
|
|
def showDiagnosis_vert(imgDropped):
    """Run the model on one tile and show each result on its own row.

    Each row has a caption column (25%) and an image column (75%). The image
    column is drawn with the same helpers as showDiagnosis_horiz, so the
    prediction and GradCAM rows show the actual model outputs. (Previously
    every row displayed the raw tile.)
    """
    st.write("#")

    strPth_tilRaw, xai_pred, cam_img_bg, cam_img_wt, cam_img_vt = readyModel_getPreds(imgDropped)

    # (caption, renderer) pairs; each renderer draws into the row's image column.
    lstRows = [
        ("Raw Tile",
         lambda col: showCol_rawTil(col, strPth_tilRaw)),
        ("Prediction",
         lambda col: showCol_predTil(col, xai_pred[0], strPth_tilRaw)),
        ("GradCam: Background",
         lambda col: showCol_gradCamImg("imgGradCam_bg", col, cam_img_bg[0])),
        ("GradCam: Whole Tumor",
         lambda col: showCol_gradCamImg("imgGradCam_wt", col, cam_img_wt[0])),
        ("GradCam: Viable Tumor",
         lambda col: showCol_gradCamImg("imgGradCam_vt", col, cam_img_vt[0])),
    ]

    for strDescr, fnRender in lstRows:
        colDescr, colImage = st.columns([0.25, 0.75])
        colDescr.write(strDescr)
        fnRender(colImage)
|
|
|
|
|
def ensureDirExists(strPth):
    """Create directory ``strPth`` (and any missing parents) if it is not there yet.

    Idempotent and tolerant of concurrent callers: creation uses
    ``os.makedirs(exist_ok=True)``, so a directory created by someone else
    between the check and the call is not an error.

    Raises:
        FileExistsError: if ``strPth`` exists but is not a directory.
    """
    if os.path.isdir(strPth):
        return
    os.makedirs(strPth, exist_ok=True)
    print("TRACE: creating dir ... ", strPth)
|
|
|
|
|
def save_tilRaw(imgDropped):
    """Write an uploaded tile to ``data/tiles/raw/`` and return the saved path.

    Args:
        imgDropped: upload exposing ``name`` and ``getbuffer()`` (a Streamlit
            UploadedFile, or an io.BytesIO given a ``name`` attribute).

    Returns:
        Path of the written file. An existing file of the same name is replaced.

    Raises:
        ValueError: if the upload's name has no usable file-name component.
    """
    print("TRACE: save_tilRaw ...")

    kstrPth_tileRaw = "data/tiles/raw/"
    # The name is supplied by the client: keep only its last path component
    # (treating "\" as a separator too) so a crafted name such as "../../x"
    # cannot write outside kstrPth_tileRaw.
    strFilName = os.path.basename(imgDropped.name.replace("\\", "/"))
    if strFilName in ("", ".", ".."):
        raise ValueError("save_tilRaw: uploaded file has no usable name: %r" % (imgDropped.name,))
    strFil_tileRaw = kstrPth_tileRaw + strFilName
    print("TRACE: save_tilRaw.file ... ", strFil_tileRaw)

    ensureDirExists(kstrPth_tileRaw)

    if os.path.isfile(strFil_tileRaw):
        print("WARN: save_tilRaw.file exists; delete ... ", strFil_tileRaw)
        os.remove(strFil_tileRaw)

    with open(strFil_tileRaw, "wb") as filUpload:
        filUpload.write(imgDropped.getbuffer())
    print("TRACE: uploaded file saved to ", strFil_tileRaw)
    return strFil_tileRaw
|
|
|
|
|
def prepare_model(backbone_model="mbv3", num_classes=2):
    """Build a torchvision DeepLabV3 model whose heads output ``num_classes`` channels.

    Args:
        backbone_model: "mbv3" (MobileNetV3-Large), "r50" (ResNet-50) or
            "r101" (ResNet-101).
        num_classes: output channels of both the main and auxiliary classifier.

    Returns:
        The model built with torchvision's default pretrained weights, with the
        final 1x1 convolution of each head replaced.

    Raises:
        ValueError: for an unknown ``backbone_model``.
    """
    weights = 'DEFAULT'

    if backbone_model == "mbv3":
        # Imported lazily: only this branch needs the MobileNetV3 builder.
        from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large
        model = deeplabv3_mobilenet_v3_large(weights=weights)
    elif backbone_model == "r50":
        model = deeplabv3_resnet50(weights=weights)
    elif backbone_model == "r101":
        from torchvision.models.segmentation import deeplabv3_resnet101
        model = deeplabv3_resnet101(weights=weights)
    else:
        raise ValueError("Wrong backbone model passed. Must be one of 'mbv3', 'r50' and 'r101' ")

    # Replace only the last 1x1 conv of each head; the rest of the pretrained
    # network is kept. (torchvision attaches aux_classifier when weights are given.)
    model.classifier[-1] = nn.Conv2d(model.classifier[-1].in_channels, num_classes, kernel_size=1)
    model.aux_classifier[-1] = nn.Conv2d(model.aux_classifier[-1].in_channels, num_classes, kernel_size=1)
    return model
|
|
|
|
|
|
|
def intermediate_metric_calculation(predictions, targets, use_dice=False,
                                    smooth=1e-6, dims=(2, 3)):
    """Compute the smoothed IoU (default) or Dice score, averaged over the rest.

    Overlap and sizes are summed over ``dims``; the per-slice scores that
    remain (presumably one per sample and class for N x C x H x W masks —
    TODO confirm with callers) are averaged into a single scalar.

    Both formulas add ``smooth`` once to numerator and denominator, so a slice
    where prediction and target are both empty scores exactly 1:

        IoU  = (I + s) / (P + T - I + s)
        Dice = (2 I + s) / (P + T + s)

    Args:
        predictions, targets: tensors of the same shape (binary / one-hot masks).
        use_dice: return Dice instead of IoU.
        smooth: smoothing constant avoiding 0/0.
        dims: dimensions to reduce over.

    Returns:
        0-d tensor with the mean score.
    """
    intersection = (predictions * targets).sum(dim=dims)
    summation = predictions.sum(dim=dims) + targets.sum(dim=dims)

    if use_dice:
        # Smoothing is added to 2*I, not 2*(I+s): the latter scored empty-vs-empty as 2.
        metric = (2.0 * intersection + smooth) / (summation + smooth)
    else:
        union = summation - intersection
        metric = (intersection + smooth) / (union + smooth)

    return metric.mean()
|
|
|
|
|
def convert_2_onehot(matrix, num_classes=3):
    '''
    One-hot encode the per-pixel argmax of an N x C x H x W score tensor.

    Returns an N x num_classes x H x W integer tensor with a single 1 per pixel.
    '''
    class_idx = torch.argmax(matrix, dim=1)                   # N x H x W
    onehot_last = F.one_hot(class_idx, num_classes=num_classes)  # N x H x W x K
    return onehot_last.permute(0, 3, 1, 2)                    # back to channel-first
|
|
|
|
|
|
|
class Loss(nn.Module):
    """Mean pixel-wise cross-entropy between logits and channel-first targets.

    ``targets`` is reduced to class indices with an argmax over dim 1
    (e.g. one-hot N x C x H x W masks) before ``F.cross_entropy``.
    """

    def forward(self, predictions, targets):
        class_idx = targets.argmax(dim=1)
        return F.cross_entropy(predictions, class_idx, reduction="mean")
|
|
|
|
|
class Metric(nn.Module):
    """Segmentation score module: IoU by default, Dice when ``use_dice`` is True.

    ``forward`` one-hot encodes the argmax of ``predictions`` and scores it
    against ``targets`` with intermediate_metric_calculation.
    """

    def __init__(self, num_classes=3, smooth=1e-6, use_dice=False):
        super().__init__()
        self.num_classes, self.smooth, self.use_dice = num_classes, smooth, use_dice

    def forward(self, predictions, targets):
        onehot_preds = convert_2_onehot(predictions, num_classes=self.num_classes)
        return intermediate_metric_calculation(onehot_preds, targets,
                                               use_dice=self.use_dice,
                                               smooth=self.smooth)
|
|
|
|
|
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU."""
    strDevice = 'cpu'
    if torch.cuda.is_available():
        strDevice = 'cuda'
    return torch.device(strDevice)
|
|
|
|
|
def readyBaseModel():
    """Build the segmentation model configured by ``backbone_model_name`` / ``NUM_CLASSES``.

    The trained weights are loaded separately by readyModelWithWeights().
    (The unused metric, loss and Adam optimizer that used to be built here on
    every call were dead training-time leftovers and have been removed.)
    """
    return prepare_model(backbone_model=backbone_model_name, num_classes=NUM_CLASSES)
|
|
|
|
|
def readyModelWithWeights(mdlBase):
    """Load the state dict at ``model_path`` into ``mdlBase`` (on CPU) and set eval mode.

    Returns the same model object, now in eval mode.
    NOTE(review): torch.load unpickles the file; it is a local bundled file
    here, but on torch >= 1.13 consider passing weights_only=True.
    """
    print("TRACE: loading model with weights ... ", model_path)
    dctState = torch.load(model_path, map_location=torch.device('cpu'))
    mdlBase.load_state_dict(dctState)
    mdlBase.eval()
    return mdlBase
|
|
|
|
|
class SegmentationModelOutputWrapper(torch.nn.Module):
    """Expose only the ``"out"`` entry of a model that returns a dict of outputs.

    GradCAM expects a plain tensor output; torchvision segmentation models
    return a dict, so this wrapper unwraps it.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        outputs = self.model(x)
        return outputs["out"]
|
|
|
|
|
def readyModelWithXAI(mdlWeighted):
    """Wrap a trained model for GradCAM: tensor-only output, eval mode, on CPU."""
    # Module.eval() and Module.to() both return the module itself, so they chain.
    return SegmentationModelOutputWrapper(mdlWeighted).eval().to('cpu')
|
|
|
|
|
|
|
def val_filToTensor(strPth_fil):
    """Load an image file as a 1 x 3 x H x W tensor with values in [0, 1].

    The image is converted to RGB and divided by 255 in NumPy, so the tensor
    is float64 (to_tensor does not rescale non-uint8 arrays).
    """
    # Context manager closes the file handle; the converted copy stays usable.
    with Image.open(strPth_fil) as img_fil:
        ary_rgb = np.asarray(img_fil.convert("RGB")) / 255
    return to_tensor(ary_rgb).unsqueeze(0)
|
|
|
|
|
|
|
def val_aryToTensor(pth_fil, ary_fils):
    """Load every file in ``ary_fils`` (names relative to directory ``pth_fil``) as a tensor.

    Returns a list with one val_filToTensor() result per file name. The
    previous version passed two arguments to the one-argument
    val_filToTensor and therefore always raised TypeError.
    """
    return [val_filToTensor(os.path.join(pth_fil, str_filName)) for str_filName in ary_fils]
|
|
|
|
|
def predXai_tile(mdl_xai, strPth_tileRaw):
    """Run the wrapped model on one saved tile.

    Returns:
        ``(output, input_batch)``: the model's output tensor and the 1-image
        input batch as produced by val_filToTensor (before ``.float()``).
    """
    print("TRACE: get tensor from single file ... ", strPth_tileRaw)
    val_tensorBatch = val_filToTensor(strPth_tileRaw)

    print("TRACE: get mdl_xai prediction ...")
    # Pure inference: callers only argmax this output and GradCAM performs its
    # own forward/backward pass, so don't keep an autograd graph alive here.
    with torch.no_grad():
        output = mdl_xai(val_tensorBatch.float().to('cpu'))

    print("TRACE: predXai_tile return ...")
    return output, val_tensorBatch
|
|
|
|
|
class SemanticSegmentationTarget:
    """GradCAM target: sum of one class's output scores inside a mask.

    ``mask`` is a NumPy array; it is moved to CUDA when CUDA is available.
    """

    def __init__(self, category, mask):
        self.category = category
        tns_mask = torch.from_numpy(mask)
        self.mask = tns_mask.cuda() if torch.cuda.is_available() else tns_mask

    def __call__(self, model_output):
        class_scores = model_output[self.category, :, :]
        return (class_scores * self.mask).sum()
|
|
|
|
|
def predGradCam_tile(output_xaiPred, mdl_xai, val_image_batch):
    """Compute GradCAM overlays for every semantic class of every image in a batch.

    For each image and each class (background, whole tumor, viable tumor),
    the GradCAM target is that class's score summed over the pixels the model
    assigned to the class.

    Args:
        output_xaiPred: model scores for the batch, class axis at dim 1
            (as returned by predXai_tile).
        mdl_xai: SegmentationModelOutputWrapper; GradCAM hooks
            ``mdl_xai.model.backbone.layer4``.
        val_image_batch: input batch, indexed per image and permuted to
            H x W x C for display (values presumably in [0, 1] — see
            val_filToTensor).

    Returns:
        ``(cam_img_bg, cam_img_wt, cam_img_vt)``: lists with one
        show_cam_on_image RGB overlay per image.
    """
    print("TRACE: predGradCam initialize ...")
    cam_img_bg = []
    cam_img_wt = []
    cam_img_vt = []

    sem_classes = ['__background__', 'whole_tumor', 'viable_tumor']
    sem_class_to_idx = {cls: idx for (idx, cls) in enumerate(sem_classes)}

    # Keep the batch axis (no squeeze): mask[i] below must be image i's full
    # H x W mask. Squeezing a batch of one left an H x W array, so mask[0]
    # picked pixel row 0 and broadcast it over every row of the tile.
    seg_mask = torch.argmax(output_xaiPred, dim=1).cpu().numpy()

    bg_category = sem_class_to_idx["__background__"]
    wt_category = sem_class_to_idx["whole_tumor"]
    vt_category = sem_class_to_idx["viable_tumor"]

    # (trace message, class index, per-image masks, output list) for each pass.
    lstPasses = (
        ("TRACE: process the background ...", bg_category,
         np.float32(seg_mask == bg_category), cam_img_bg),
        ("TRACE: process whole tumors ...", wt_category,
         np.float32(seg_mask == wt_category), cam_img_wt),
        ("TRACE: process viable tumors ...", vt_category,
         np.float32(seg_mask == vt_category), cam_img_vt),
    )

    target_layers = [mdl_xai.model.backbone.layer4]

    # One GradCAM instance (one set of hooks) serves all passes; each call
    # recomputes activations and gradients for its own target.
    with GradCAM(model=mdl_xai,
                 target_layers=target_layers,
                 use_cuda=torch.cuda.is_available()) as cam:
        for i in range(len(val_image_batch)):
            rgb_img = np.float32(val_image_batch[i].permute(1, 2, 0))
            rgb_tensor = val_image_batch[i].unsqueeze(0).float()

            for strTrace, category, mask_float, lstOut in lstPasses:
                print(strTrace)
                targets = [SemanticSegmentationTarget(category, mask_float[i])]
                grayscale_cam = cam(input_tensor=rgb_tensor, targets=targets)[0, :]
                lstOut.append(show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True))

    return cam_img_bg, cam_img_wt, cam_img_vt
|
|