import os

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from PathDino import get_pathDino_model

# Allow duplicate OpenMP runtimes (a common workaround for the
# "libiomp5 already initialized" crash on some conda/Windows setups)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the PathDino model and its matching image transforms
model, image_transforms = get_pathDino_model("PathDino512.pth")

# ---------------------------------------------------------------------------
# Sidebar: model description, figures, and citation
# ---------------------------------------------------------------------------
st.sidebar.markdown("### PathDino")
st.sidebar.markdown(
    "PathDino is a lightweight histopathology transformer consisting of just "
    "five small vision transformer blocks. It is a customized ViT architecture, "
    "fine-tuned to the nuances of histology images. It not only exhibits "
    "superior performance but also effectively reduces susceptibility to "
    "overfitting, a common challenge in histology image analysis.\n\n"
)

st.sidebar.image(
    "images/HistRotate.png",
    caption="A 360° rotation augmentation for training models on histopathology "
            "images. Unlike natural images, where rotation may change the context "
            "of the visual data, rotating a histopathology patch preserves its "
            "context and improves learning, yielding more reliable embeddings.",
    width=300,
)
st.sidebar.image(
    "images/FigPathDino_parameters_FLOPs_compare.png",
    caption="PathDino vs. its counterparts: number of parameters (millions) "
            "against patch-level retrieval with macro-averaged F-score of the "
            "majority vote (MV@5) on the CAMELYON16 dataset. Bubble size "
            "represents FLOPs.",
    width=300,
)
st.sidebar.image(
    "images/ActivationMap.png",
    caption="Attention visualization. PathDino outperforms HIPT-small and "
            "DinoSSLPath despite being trained on a smaller dataset of 6 million "
            "TCGA patches; DinoSSLPath and HIPT were trained on much larger "
            "datasets of 19 million and 104 million TCGA patches, respectively.",
    width=300,
)

st.sidebar.markdown("### Citation")
st.sidebar.markdown("""
```bibtex
@article{alfasly2023rotationagnostic,
    title={Rotation-Agnostic Image Representation Learning for Digital Pathology},
    author={Saghir Alfasly and Abubakr Shafique and Peyman Nejat and Jibran Khan
            and Areej Alsaafin and Ghazal Alabtah and H. R. Tizhoosh},
    year={2023},
    eprint={2311.08359},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```""")
st.sidebar.markdown(
    "KIMIA Lab, Department of Artificial Intelligence and Informatics,\n"
    "Mayo Clinic,\nRochester, MN, USA"
)
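# Note: `image_transforms` returned by get_pathDino_model() is not used below;
# the app builds its own preprocessing pipeline in generate_activation_maps().
# For feature extraction outside this app, a minimal sketch (assuming the
# model's forward pass returns an embedding vector) would be:
#     embedding = model(image_transforms(pil_image).unsqueeze(0).to(device))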
def visualize_attention_ViT(model, img, patch_size=16):
    """Return one attention map per head for the CLS token of the last block."""
    w_featmap = img.shape[-2] // patch_size
    h_featmap = img.shape[-1] // patch_size

    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]  # number of heads

    # Keep only the CLS-token attention over the image patches
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, w_featmap, h_featmap)

    # Upsample each head's map from the patch grid back to pixel resolution;
    # move to CPU before converting to NumPy so this also works on CUDA
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest"
    )[0].detach().cpu().numpy()

    return [attentions[j] for j in range(nh)]
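# Shape walk-through: for an (H, W) input with P = patch_size,
# get_last_selfattention() yields a tensor of shape
# (1, nh, 1 + H//P * W//P, 1 + H//P * W//P). The slice [0, :, 0, 1:] keeps,
# for each head, how strongly the CLS token attends to every image patch;
# it is then reshaped to (nh, H//P, W//P) and upsampled above.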
def generate_activation_maps(image, patch_size=16):
    """Preprocess a PIL image and return its per-head attention maps."""
    img = np.array(image)  # (H, W, C)

    # Crop so both spatial dimensions are divisible by the patch size
    w = img.shape[0] - img.shape[0] % patch_size
    h = img.shape[1] - img.shape[1] % patch_size

    preprocess = transforms.Compose([
        transforms.Resize((img.shape[0], img.shape[1])),
        transforms.CenterCrop((w, h)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        return visualize_attention_ViT(model=model, img=img_tensor,
                                       patch_size=patch_size)


# ---------------------------------------------------------------------------
# Main UI
# ---------------------------------------------------------------------------
st.title("PathDino - Compact ViT for Histopathology Image Analysis")
st.write("Upload a histology image to view the attention maps.")

uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

if uploaded_image is not None:
    st.image(uploaded_image, caption="Uploaded Image", width=500)

    image = Image.open(uploaded_image).convert('RGB')
    attention_list = generate_activation_maps(image)

    st.subheader("Attention Maps of the input image")

    # Lay the per-head maps out on a 2 x 3 grid, two heads per row
    grid = st.columns(2) + st.columns(2) + st.columns(2)
    for col, attention_map in zip(grid, attention_list):
        fig, ax = plt.subplots()
        ax.imshow(attention_map)
        ax.axis("off")  # hide ticks and axis labels
        col.pyplot(fig)
        plt.close(fig)
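# ---------------------------------------------------------------------------
# Usage (a minimal sketch; the file name and repo layout are assumptions):
# with PathDino512.pth and the images/ folder in the working directory, run
#     pip install streamlit torch torchvision pillow matplotlib numpy
#     streamlit run pathdino_app.py
# ---------------------------------------------------------------------------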