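"""Streamlit demo for PathDino: upload a histopathology patch and visualize
the per-head self-attention maps of the model's final transformer block.

Launch with `streamlit run <this_file>.py` (the saved filename is up to you).
"""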
import os

# Work around the "duplicate OpenMP runtime" crash that can occur on some
# platforms when multiple libraries bundle their own copy of libiomp; the
# variable must be set before the OpenMP runtime is first loaded.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import matplotlib.pyplot as plt
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from PathDino import get_pathDino_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load PathDino model and image transforms
model, image_transforms = get_pathDino_model("PathDino512.pth")
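# Assumption: get_pathDino_model loads the ViT from the given checkpoint and
# returns (model, transforms); "PathDino512.pth" must be present locally.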
st.sidebar.markdown("### PathDino")
st.sidebar.markdown(
    "PathDino is a lightweight histology transformer built from just five small "
    "vision transformer blocks. It is a customized ViT architecture tuned to the "
    "nuances of histological images: it delivers strong performance while "
    "reducing susceptibility to overfitting, a common challenge in histology "
    "image analysis.\n\n"
)
default_image_url_compare = "images/HistRotate.png"
st.sidebar.image(default_image_url_compare, caption='A 360° rotation augmentation for training models on histopathology images. Unlike natural images, where rotation may change the context of the visual data, rotating a histopathology patch preserves its context and improves training, yielding more reliable embeddings.', width=500)
default_image_url_compare = "images/FigPathDino_parameters_FLOPs_compare.png"
st.sidebar.image(default_image_url_compare, caption='PathDino vs. its counterparts: number of parameters (millions) against patch-level retrieval performance, measured as the macro-averaged F-score of the top-5 majority vote (MV@5) on the CAMELYON16 dataset. Bubble size represents FLOPs.', width=500)
default_image_url_compare = "images/ActivationMap.png"
st.sidebar.image(default_image_url_compare, caption='Attention visualization. PathDino produces cleaner attention patterns than HIPT-small and DinoSSLPath, despite being trained on a smaller dataset of 6 million TCGA patches; DinoSSLPath and HIPT were trained on much larger datasets of 19 million and 104 million TCGA patches, respectively.', width=500)
def visualize_attention_ViT(model, img, patch_size=16):
    """Return one upsampled CLS-attention map per head."""
    w_featmap = img.shape[-2] // patch_size
    h_featmap = img.shape[-1] // patch_size
    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]  # number of heads
    # Keep only the CLS token's attention over the image patches
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    # Upsample each head's map to pixel resolution; move to CPU before
    # converting to NumPy so this also works when running on CUDA
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest"
    )[0].detach().cpu().numpy()
    return [attentions[j] for j in range(nh)]
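# This mirrors DINO-style attention visualization: each map is the CLS token's
# attention over the image patches for a single head, upsampled with
# nearest-neighbor interpolation back to the input resolution.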
# Preprocess an uploaded image and compute its attention maps
def generate_activation_maps(image):
    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop(512),
        transforms.ToTensor(),
        # Standard ImageNet normalization statistics
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image_tensor = preprocess(image)
    img = image_tensor.unsqueeze(0).to(device)
    # No gradients are needed for visualization
    with torch.no_grad():
        attention_list = visualize_attention_ViT(model=model, img=img, patch_size=16)
    return attention_list
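# Hypothetical usage outside Streamlit (a sketch; "patch.png" is a placeholder
# path, not a file shipped with the repo):
#
#   maps = generate_activation_maps(Image.open("patch.png").convert("RGB"))
#   print(len(maps))  # one 2-D array per attention head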
# Streamlit UI
st.title("PathDino - Compact ViT for Histopathology Image Analysis")
st.write("Upload a histology image to view its attention maps.")
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
def render_attention_map(col, attention_map):
    # Draw a single head's attention map with axes and ticks hidden
    fig, ax = plt.subplots()
    ax.imshow(attention_map)
    ax.axis('off')
    col.pyplot(fig)
    plt.close(fig)

if uploaded_image is not None:
    columns = st.columns(3)
    columns[1].image(uploaded_image, caption="Uploaded Image", width=300)

    # Load the image and compute one attention map per head
    uploaded_image = Image.open(uploaded_image).convert('RGB')
    attention_list = generate_activation_maps(uploaded_image)

    st.subheader("Attention Maps of the input image")

    # Lay the heads out over two rows of columns (assumes an even head count)
    half = len(attention_list) // 2
    for index, col in enumerate(st.columns(half)):
        render_attention_map(col, attention_list[index])
    for index, col in enumerate(st.columns(half)):
        render_attention_map(col, attention_list[index + half])