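"""Streamlit demo: visualize PathDino self-attention maps on a histology image.

Assuming this file is saved as app.py, run it with:

    streamlit run app.py
"""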
import streamlit as st
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import torch.nn as nn


from PathDino import get_pathDino_model

import os
# Work around the "duplicate OpenMP runtime" crash that PyTorch + MKL can
# trigger on some systems (notably conda installs).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load PathDino model and image transforms, then move the model to the chosen
# device and switch to inference mode so attention extraction is deterministic.
model, image_transforms = get_pathDino_model("PathDino512.pth")
model = model.to(device).eval()


st.sidebar.markdown("### PathDino")
st.sidebar.markdown(
    "PathDino is a lightweight histology transformer consisting of just five small vision "
    "transformer blocks. This compact ViT architecture is tailored to the nuances of "
    "histological images: it delivers strong performance while reducing susceptibility to "
    "overfitting, a common challenge in histology image analysis.\n\n"
)
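# Illustrative only: a five-block ViT of the kind described above could be
# sketched with timm (hypothetical hyperparameters; the real architecture and
# weights come from PathDino.py and PathDino512.pth):
#
#   from timm.models.vision_transformer import VisionTransformer
#   pathdino_like = VisionTransformer(img_size=512, patch_size=16,
#                                     embed_dim=384, depth=5, num_heads=6)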

default_image_url_compare = "images/HistRotate.png"
st.sidebar.image(default_image_url_compare, caption='A 360 rotation augmentation for training models on histopathology images. Unlike training on natural images where the rotation may change the context of the visual data, rotating a histopathology patch does not change the context and it improves the learning process for better reliable embedding learning.', width=500)

default_image_url_compare = "images/FigPathDino_parameters_FLOPs_compare.png"
st.sidebar.image(default_image_url_compare, caption='PathDino Vs its counterparts. Number of Parameters (Millions) vs the patch-level retrieval with macro avg F-score of majority vote (MV@5) on CAMELYON16 dataset. The bubble size represents the FLOPs.', width=500)

default_image_url_compare = "images/ActivationMap.png"
st.sidebar.image(default_image_url_compare, caption='Attention Visualization. When visualizing attention patterns, our PathDino transformer outperforms HIPT-small and DinoSSLPath, despite being trained on a smaller dataset of 6 million TCGA patches. In contrast, DinoSSLPath and HIPT were trained on much larger datasets, with 19 million and 104 million TCGA patches, respectively.', width=500)



def visualize_attention_ViT(model, img, patch_size=16):
    """Return one attention map per head for the CLS token of the last layer.

    `img` is a (1, 3, H, W) tensor; each returned map is upsampled back to
    the input resolution so it can be overlaid on the image.
    """
    w_featmap = img.shape[-2] // patch_size
    h_featmap = img.shape[-1] // patch_size
    attentions = model.get_last_selfattention(img.to(device))
    nh = attentions.shape[1]  # number of heads
    # Keep only the CLS token's attention over the image patches
    # (drop the CLS-to-CLS entry at index 0).
    attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
    attentions = attentions.reshape(nh, w_featmap, h_featmap)
    # Upsample each patch-level map back to pixel resolution; move to CPU
    # before converting to NumPy so this also works on CUDA.
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest"
    )[0].detach().cpu().numpy()
    return [attentions[j] for j in range(nh)]

# Define the function to generate activation maps
def generate_activation_maps(image):
    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.CenterCrop(512),  # no-op after the exact resize, kept for safety
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])  # ImageNet statistics
    ])
    image_tensor = preprocess(image)
    img = image_tensor.unsqueeze(0).to(device)
    # Generate activation maps
    with torch.no_grad():
        attention_list = visualize_attention_ViT(model=model, img=img, patch_size=16)
    return attention_list

# Streamlit UI
st.title("PathDino - Compact ViT for Histopathology Image Analysis")
st.write("Upload a histology image to view the activation maps.")

uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

if uploaded_image is not None:
    columns = st.columns(3)
    columns[1].image(uploaded_image, caption="Uploaded Image", width=300)

    # Load the image and apply preprocessing
    uploaded_image = Image.open(uploaded_image).convert('RGB')
    attention_list = generate_activation_maps(uploaded_image)
    st.subheader("Attention maps of the input image")

    # Lay the per-head attention maps out in two rows of equal width.
    half = len(attention_list) // 2
    columns = st.columns(half)
    columns2 = st.columns(half)

    def render_map(col, attention_map):
        # Show one head's attention map with all axes and ticks hidden.
        fig, ax = plt.subplots()
        ax.imshow(attention_map)
        ax.axis("off")
        col.pyplot(fig)
        plt.close(fig)

    for index, col in enumerate(columns):
        render_map(col, attention_list[index])

    for index, col in enumerate(columns2):
        render_map(col, attention_list[index + half])