# PathDino/app.py
import os

# Work around "duplicate OpenMP runtime" crashes; must be set before torch/numpy load
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import streamlit as st
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image

from PathDino import get_pathDino_model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load PathDino model and image transforms
model, image_transforms = get_pathDino_model("PathDino512.pth")
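
# Illustrative sketch (not part of the original app flow): embedding a single patch
# with the loaded model and its transforms. This assumes the forward pass returns
# the CLS embedding, as in standard DINO-style ViTs.
def embed_patch(pil_image):
    """Return a feature vector for one histology patch (illustrative helper)."""
    tensor = image_transforms(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        return model(tensor).squeeze(0).cpu()
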
st.sidebar.markdown("### PathDino")
st.sidebar.markdown(
    "PathDino is a lightweight histopathology transformer consisting of just five "
    "small vision transformer blocks. It is a compact ViT architecture tuned to the "
    "nuances of histology images: it delivers strong performance while being less "
    "susceptible to overfitting, a common challenge in histology image analysis.\n\n"
)
default_image_url_compare = "images/HistRotate.png"
st.sidebar.image(default_image_url_compare, caption='A 360 rotation augmentation for training models on histopathology images. Unlike training on natural images where the rotation may change the context of the visual data, rotating a histopathology patch does not change the context and it improves the learning process for better reliable embedding learning.', width=300)
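
# Hedged sketch of the rotation augmentation illustrated above, using torchvision's
# RandomRotation with a full (0, 360) degree range; the exact range is an assumption
# taken from the caption, not something this app defines.
hist_rotate = transforms.RandomRotation(degrees=(0, 360))
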
default_image_url_compare = "images/FigPathDino_parameters_FLOPs_compare.png"
st.sidebar.image(default_image_url_compare, caption='PathDino Vs its counterparts. Number of Parameters (Millions) vs the patch-level retrieval with macro avg F-score of majority vote (MV@5) on CAMELYON16 dataset. The bubble size represents the FLOPs.', width=300)
default_image_url_compare = "images/ActivationMap.png"
st.sidebar.image(default_image_url_compare, caption='Attention Visualization. When visualizing attention patterns, our PathDino transformer outperforms HIPT-small and DinoSSLPath, despite being trained on a smaller dataset of 6 million TCGA patches. In contrast, DinoSSLPath and HIPT were trained on much larger datasets, with 19 million and 104 million TCGA patches, respectively.', width=300)
st.sidebar.markdown("### Citation")
# Create a code block for citations
st.sidebar.markdown("""
```markdown
@article{alfasly2023rotationagnostic,
title={Rotation-Agnostic Image Representation Learning for Digital Pathology},
author={Saghir Alfasly and Abubakr Shafique and Peyman Nejat and Jibran Khan and Areej Alsaafin and Ghazal Alabtah and H. R. Tizhoosh},
year={2023},
eprint={2311.08359},
archivePrefix={arXiv},
primaryClass={cs.CV}
}""")
st.sidebar.markdown("\n\n")
st.sidebar.markdown("KIMIA Lab, Department of Artificial Intelligence and Informatics, \n Mayo Clinic, \n Rochester, MN, USA")
def visualize_attention_ViT(model, img, patch_size=16):
    """Return one pixel-resolution attention map per head for the CLS token."""
    w_featmap = img.shape[-2] // patch_size
    h_featmap = img.shape[-1] // patch_size
    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]  # number of heads
    # Keep only the CLS token's attention over the image patches (drop CLS-to-CLS)
    attentions = attentions[0, :, 0, 1:].reshape(nh, w_featmap, h_featmap)
    # Upsample each head's patch-level map back to pixel resolution
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest"
    )[0].detach().cpu().numpy()  # .cpu() so this also works when running on CUDA
    return [attentions[j] for j in range(nh)]
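
# Optional sketch (not called by the original app): min-max normalize a head's
# attention map so colour scales are comparable across heads when displayed.
def normalize_map(attn):
    """Scale a single attention map to [0, 1] for consistent visualization."""
    return (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)
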
# Define the function to generate activation maps
def generate_activation_maps(image, patch_size=16):
    img = np.array(image)
    # Crop so that height and width are divisible by the patch size
    w = img.shape[0] - img.shape[0] % patch_size
    h = img.shape[1] - img.shape[1] % patch_size
    preprocess = transforms.Compose([
        transforms.CenterCrop((w, h)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = preprocess(image).unsqueeze(0).to(device)
    # Generate activation maps
    with torch.no_grad():
        attention_list = visualize_attention_ViT(model=model, img=img, patch_size=patch_size)
    return attention_list
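
# Illustrative sketch (not called by the original app): overlay one head's
# attention on the input patch to localize what the model attends to. The
# blending weight `alpha` is an arbitrary choice for this example.
def overlay_attention(image, attn, alpha=0.5):
    """Blend an attention heatmap over a PIL image; returns an RGB float array."""
    # PIL's resize takes (width, height); attn is (height, width)
    img = np.array(image.resize((attn.shape[1], attn.shape[0]))) / 255.0
    norm = (attn - attn.min()) / (attn.max() - attn.min() + 1e-8)
    heat = plt.cm.viridis(norm)[..., :3]  # RGB heatmap from the colormap
    return (1 - alpha) * img + alpha * heat
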
# Streamlit UI
st.title("PathDino - Compact ViT for Histopathology Image Analysis")
st.write("Upload a histology image to view the attention maps.")
uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
if uploaded_image is not None:
    st.image(uploaded_image, caption="Uploaded Image", width=500)

    # Load the image and compute the per-head attention maps
    uploaded_image = Image.open(uploaded_image).convert('RGB')
    attention_list = generate_activation_maps(uploaded_image)

    st.subheader("Attention Maps of the Input Image")
    # Lay the six attention heads out as a 3x2 grid
    columns = st.columns(2) + st.columns(2) + st.columns(2)
    for col, attn in zip(columns, attention_list):
        fig, ax = plt.subplots()
        ax.imshow(attn)
        ax.axis("off")  # hide ticks and labels
        col.pyplot(fig)
        plt.close(fig)
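
# To run this demo locally (standard Streamlit CLI):
#   streamlit run app.py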