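"""Streamlit app for ischaemic stroke classification on CT and MRI DICOM slices.

An uploaded DICOM image is classified by a modality-specific model, and a Grad-CAM
heatmap of the prediction is shown alongside the windowed original image.
"""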
# Import required libraries
import os
import io
import torch
import pydicom
import numpy as np
import streamlit as st
# Import utility and custom functions
from PIL import Image
from Util.DICOM import DICOM_Utils
from Util.Custom_Model import Build_Custom_Model, reshape_transform
# Import additional MONAI and PyTorch Grad-CAM utilities
from monai.utils import set_determinism
from monai.networks.nets import SEResNet50
from monai.transforms import (
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    AsChannelFirst,
    AddChannel,
    RandSpatialCrop,
    ScaleIntensityRangePercentiles,
    Resize,
)
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
# (Int) Random seed
SEED = 0
# (Int) Number of output classes
NUM_CLASSES = 1
# (String) CT Model directory
CT_MODEL_DIRECTORY = "models/CLOTS/CT"
# (String) MRI Model directory
MRI_MODEL_DIRECTORY = "models/CLOTS/MRI"
# (Boolean) Use custom model
CUSTOM_MODEL_FLAG = True
# (List[int]) Image size
SPATIAL_SIZE = [224, 224]
# (String) CT Model file name
CT_MODEL_FILE_NAME = "best_metric_model.pth"
# (String) MRI Model file name
MRI_MODEL_FILE_NAME = "best_metric_model.pth"
# (Boolean) List model modules
LIST_MODEL_MODULES = False
# (String) CT Model name
CT_MODEL_NAME = "swin_base_patch4_window7_224"
# (String) MRI Model name
MRI_MODEL_NAME = "swin_base_patch4_window7_224"
# (Float) CT Model inference threshold
CT_INFERENCE_THRESHOLD = 0.5
# (Float) MRI Model inference threshold
MRI_INFERENCE_THRESHOLD = 0.5
# (Int) Display CAM Class ID
CAM_CLASS_ID = 0
# (Int) Default Window Center for CT image display
DEFAULT_CT_WINDOW_CENTER = 40
# (Int) Default Window Width for CT image display
DEFAULT_CT_WINDOW_WIDTH = 100
# (Int) Default Window Center for MRI image display
DEFAULT_MRI_WINDOW_CENTER = 400
# (Int) Default Window Width for MRI image display
DEFAULT_MRI_WINDOW_WIDTH = 1000
# (Int) Minimum value for Window Center
WINDOW_CENTER_MIN = -600
# (Int) Maximum value for Window Center
WINDOW_CENTER_MAX = 1000
# (Int) Minimum value for Window Width
WINDOW_WIDTH_MIN = 1
# (Int) Maximum value for Window Width
WINDOW_WIDTH_MAX = 3000
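# Three transform pipelines are used: eval_transforms prepares the model input
# (percentile-based intensity scaling plus a resize to SPATIAL_SIZE), cam_transforms only
# resizes so the CAM overlay matches the model's spatial size, and original_transforms
# keeps the image at its native resolution for download.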
# Evaluation Transforms
eval_transforms = Compose(
    [
        AsChannelFirst(),
        ScaleIntensityRangePercentiles(lower=20, upper=80, b_min=0.0, b_max=1.0, clip=False, relative=True),
        Resize(spatial_size=SPATIAL_SIZE)
    ]
)
# CAM Transforms
cam_transforms = Compose(
    [
        AsChannelFirst(),
        Resize(spatial_size=SPATIAL_SIZE)
    ]
)
# Original Transforms
original_transforms = Compose(
    [
        AsChannelFirst()
    ]
)
# Function to convert PIL Image to byte stream in PNG format for downloading
def image_to_bytes(image):
    byte_stream = io.BytesIO()
    image.save(byte_stream, format='PNG')
    return byte_stream.getvalue()
# Convert the file size from bytes to megabytes
def bytes_to_megabytes(file_size_bytes):
    # Convert bytes to MB (1 MB = 1024 * 1024 bytes)
    file_size_megabytes = round(file_size_bytes / (1024 * 1024), 2)
    return str(file_size_megabytes) + " MB"  # Rounded to 2 decimal places for readability
def meta_tensor_to_numpy(meta_tensor):
"""
Convert a PyTorch MetaTensor to a NumPy array
"""
# Ensure the MetaTensor is on the CPU
meta_tensor = meta_tensor.cpu()
# Convert the MetaTensor to a PyTorch tensor
torch_tensor = meta_tensor.to(dtype=torch.float32)
# Convert the PyTorch tensor to a NumPy array
numpy_array = torch_tensor.detach().numpy()
return numpy_array
set_determinism(seed=SEED)
torch.manual_seed(SEED)
# Parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_CUDA = False
if device == torch.device("cuda"):
    USE_CUDA = True
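# Build a classifier and restore its trained weights: either the custom backbone from
# Build_Custom_Model (presumably a timm Swin Transformer, given the model names above) or
# a MONAI SEResNet50, loaded onto the selected device and switched to eval mode.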
def load_model(root_dir, model_name, model_file_name):
    if CUSTOM_MODEL_FLAG:
        model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
    else:
        model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
    model.load_state_dict(torch.load(os.path.join(root_dir, model_file_name), map_location=device))
    model.eval()
    return model
ct_model = load_model(CT_MODEL_DIRECTORY, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
mri_model = load_model(MRI_MODEL_DIRECTORY, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)
if LIST_MODEL_MODULES:
    for ct_name, _ in ct_model.named_modules():
        print(ct_name)
    for mri_name, _ in mri_model.named_modules():
        print(mri_name)
# Initialize Streamlit
st.title("Analyze")
# Use Streamlit's number_input to adjust WINDOW_CENTER and WINDOW_WIDTH
st.sidebar.header("Windowing Parameters for DICOM")
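# Note: DICOM windowing (center/width) selects the range of raw pixel intensities mapped
# to the displayable grayscale range; in this app it affects only the visualised and
# downloaded images, not the model input, which is normalised by eval_transforms instead.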
MRI_WINDOW_CENTER = st.sidebar.number_input("MRI Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX, value=DEFAULT_MRI_WINDOW_CENTER, step=1)
MRI_WINDOW_WIDTH = st.sidebar.number_input("MRI Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_MRI_WINDOW_WIDTH, step=1)
CT_WINDOW_CENTER = st.sidebar.number_input("CT Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX, value=DEFAULT_CT_WINDOW_CENTER, step=1)
CT_WINDOW_WIDTH = st.sidebar.number_input("CT Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_CT_WINDOW_WIDTH, step=1)
uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
if uploaded_mri_file is not None:
    # Read DICOM file into NumPy array
    dicom_data = pydicom.dcmread(uploaded_mri_file)
    dicom_array = dicom_data.pixel_array
    # Convert the data type to float32
    dicom_array = dicom_array.astype(np.float32)
    # Then add a channel dimension
    dicom_array = dicom_array[:, :, np.newaxis]
    # To check file details
    file_details = {"File_Name": uploaded_mri_file.name, "File_Type": uploaded_mri_file.type, "File_Size": bytes_to_megabytes(uploaded_mri_file.size), "File_Dimension": str((dicom_array.shape[0], dicom_array.shape[1]))}
    st.write(file_details)
    transformed_array = eval_transforms(dicom_array)
    # Convert to PyTorch tensor and move to device
    image_tensor = transformed_array.clone().detach().unsqueeze(0).to(device)
    # Predict
    with torch.no_grad():
        outputs = mri_model(image_tensor).sigmoid().to("cpu").numpy()
    prob = outputs[0][0]
    CLOTS_CLASSIFICATION = False
    if prob >= MRI_INFERENCE_THRESHOLD:
        CLOTS_CLASSIFICATION = True
    st.header("MRI Classification")
    st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
    st.subheader(f"Confidence : {prob * 100:.1f}%")
    # Load the original DICOM image for download
    download_image_tensor = original_transforms(dicom_array).unsqueeze(0).to(device)
    download_image_tensor = download_image_tensor.squeeze()
    # Transform the download image and apply windowing
    download_image_numpy = meta_tensor_to_numpy(download_image_tensor)
    windowed_download_image = DICOM_Utils.apply_windowing(download_image_numpy, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
    # Streamlit button to trigger image download
    image_data = image_to_bytes(Image.fromarray(windowed_download_image))
    st.download_button(
        label="Download MRI Image",
        data=image_data,
        file_name="downloaded_mri_image.png",
        mime="image/png"
    )
    # Load the original DICOM image for display
    display_image_tensor = cam_transforms(dicom_array).unsqueeze(0).to(device)
    display_image_tensor = display_image_tensor.squeeze()
    # Transform the image and apply windowing
    display_image_numpy = meta_tensor_to_numpy(display_image_tensor)
    windowed_image = DICOM_Utils.apply_windowing(display_image_numpy, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
    st.image(Image.fromarray(windowed_image), caption="Original MRI Visualization", use_column_width=True)
    # Expand to three channels
    windowed_image = np.expand_dims(windowed_image, axis=2)
    windowed_image = np.tile(windowed_image, [1, 1, 3])
    # Ensure float32 type
    windowed_image = windowed_image.astype(np.float32)
    # Normalize to [0, 1] range
    windowed_image = np.float32(windowed_image) / 255
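    # Note on the CAM below: the chosen target layer is the final normalisation layer of
    # the backbone (mri_model.model.norm), and reshape_transform from Util.Custom_Model is
    # presumably what maps the transformer's token sequence back into a 2D feature map, as
    # pytorch_grad_cam expects for ViT/Swin-style models.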
    # Build the CAM (Class Activation Map)
    target_layers = [mri_model.model.norm]
    cam = GradCAM(model=mri_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=USE_CUDA)
    grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
    grayscale_cam = grayscale_cam[0, :]
    # Overlay the CAM heatmap on the windowed image
    visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
    st.image(Image.fromarray(visualization), caption="CAM MRI Visualization", use_column_width=True)
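# The CT branch below mirrors the MRI branch above, using the CT model, CT inference
# threshold and CT windowing parameters.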
uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
if uploaded_ct_file is not None:
    # Read DICOM file into NumPy array
    dicom_data = pydicom.dcmread(uploaded_ct_file)
    dicom_array = dicom_data.pixel_array
    # Convert the data type to float32
    dicom_array = dicom_array.astype(np.float32)
    # Then add a channel dimension
    dicom_array = dicom_array[:, :, np.newaxis]
    # To check file details
    file_details = {"File_Name": uploaded_ct_file.name, "File_Type": uploaded_ct_file.type, "File_Size": bytes_to_megabytes(uploaded_ct_file.size), "File_Dimension": str((dicom_array.shape[0], dicom_array.shape[1]))}
    st.write(file_details)
    transformed_array = eval_transforms(dicom_array)
    # Convert to PyTorch tensor and move to device
    image_tensor = transformed_array.clone().detach().unsqueeze(0).to(device)
    # Predict
    with torch.no_grad():
        outputs = ct_model(image_tensor).sigmoid().to("cpu").numpy()
    prob = outputs[0][0]
    CLOTS_CLASSIFICATION = False
    if prob >= CT_INFERENCE_THRESHOLD:
        CLOTS_CLASSIFICATION = True
    st.header("CT Classification")
    st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
    st.subheader(f"Confidence : {prob * 100:.1f}%")
    # Load the original DICOM image for download
    download_image_tensor = original_transforms(dicom_array).unsqueeze(0).to(device)
    download_image_tensor = download_image_tensor.squeeze()
    # Transform the download image and apply windowing
    download_image_numpy = meta_tensor_to_numpy(download_image_tensor)
    windowed_download_image = DICOM_Utils.apply_windowing(download_image_numpy, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
    # Streamlit button to trigger image download
    image_data = image_to_bytes(Image.fromarray(windowed_download_image))
    st.download_button(
        label="Download CT Image",
        data=image_data,
        file_name="downloaded_ct_image.png",
        mime="image/png"
    )
    # Load the original DICOM image for display
    display_image_tensor = cam_transforms(dicom_array).unsqueeze(0).to(device)
    display_image_tensor = display_image_tensor.squeeze()
    # Transform the image and apply windowing
    display_image_numpy = meta_tensor_to_numpy(display_image_tensor)
    windowed_image = DICOM_Utils.apply_windowing(display_image_numpy, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
    st.image(Image.fromarray(windowed_image), caption="Original CT Visualization", use_column_width=True)
    # Expand to three channels
    windowed_image = np.expand_dims(windowed_image, axis=2)
    windowed_image = np.tile(windowed_image, [1, 1, 3])
    # Ensure float32 type
    windowed_image = windowed_image.astype(np.float32)
    # Normalize to [0, 1] range
    windowed_image = np.float32(windowed_image) / 255
    # Build the CAM (Class Activation Map)
    target_layers = [ct_model.model.norm]
    cam = GradCAM(model=ct_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=USE_CUDA)
    grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
    grayscale_cam = grayscale_cam[0, :]
    # Overlay the CAM heatmap on the windowed image
    visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
    st.image(Image.fromarray(visualization), caption="CAM CT Visualization", use_column_width=True)