File size: 7,195 Bytes
be2c585
 
 
 
 
 
28de83f
be2c585
 
 
 
 
 
 
 
 
 
 
 
 
 
45532a7
3e2e594
be2c585
 
 
 
45532a7
be2c585
b5401fe
be2c585
 
b5401fe
be2c585
 
b5401fe
45532a7
 
 
 
 
 
162e915
 
 
 
 
 
 
 
45532a7
 
8176d59
45532a7
be2c585
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28de83f
 
 
 
 
be8436c
 
28de83f
be2c585
28de83f
be8436c
28de83f
 
 
be2c585
 
 
 
 
 
 
 
0da0f9a
1c4129f
be2c585
 
 
 
 
 
be8436c
 
be2c585
 
 
 
be8436c
be2c585
 
28de83f
 
 
be8436c
be2c585
 
be8436c
be2c585
 
 
 
 
 
 
 
be8436c
 
 
be2c585
 
 
be8436c
 
be2c585
 
 
 
28de83f
 
be8436c
28de83f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be2c585
be8436c
be2c585
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import streamlit as st
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import numpy as np

from PathDino import get_pathDino_model

import os
# NOTE(review): KMP_DUPLICATE_LIB_OK lets the Intel OpenMP runtime tolerate a
# second copy of itself being loaded. It is set here, *after* torch/numpy were
# imported above, so it presumably only helps if a duplicate runtime is
# initialised later — confirm it still has the intended effect.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load PathDino model and image transforms
model, image_transforms = get_pathDino_model("PathDino512.pth")
# Input batches are moved to `device` before the forward pass, so the weights
# must live there too; eval() puts the network in inference mode.
# NOTE(review): assumes get_pathDino_model returns a torch.nn.Module — confirm.
model = model.to(device)
model.eval()


st.sidebar.markdown("### PathDino")
st.sidebar.markdown(
    "PathDino is a lightweight histopathology transformer consisting of just five small vision transformer blocks. "
    "PathDino is a customized ViT architecture, finely tuned to the nuances of histology images. It not only exhibits "
    "superior performance but also effectively reduces susceptibility to overfitting, a common challenge in histology "
    "image analysis.\n\n"
)

# (local image path, caption) pairs shown in the sidebar, in display order.
_SIDEBAR_FIGURES = (
    ("images/HistRotate.png",
     'A 360 rotation augmentation for training models on histopathology images. Unlike training on natural images where the rotation may change the context of the visual data, rotating a histopathology patch does not change the context and it improves the learning process for better reliable embedding learning.'),
    ("images/FigPathDino_parameters_FLOPs_compare.png",
     'PathDino Vs its counterparts. Number of Parameters (Millions) vs the patch-level retrieval with macro avg F-score of majority vote (MV@5) on CAMELYON16 dataset. The bubble size represents the FLOPs.'),
    ("images/ActivationMap.png",
     'Attention Visualization. When visualizing attention patterns, our PathDino transformer outperforms HIPT-small and DinoSSLPath, despite being trained on a smaller dataset of 6 million TCGA patches. In contrast, DinoSSLPath and HIPT were trained on much larger datasets, with 19 million and 104 million TCGA patches, respectively.'),
)
for _figure_path, _figure_caption in _SIDEBAR_FIGURES:
    st.sidebar.image(_figure_path, caption=_figure_caption, width=300)


st.sidebar.markdown("### Citation")
# Show the BibTeX entry as a fenced code block (the fence is closed explicitly
# so the markdown is well-formed).
st.sidebar.markdown("""
```markdown
@article{alfasly2023rotationagnostic,
      title={Rotation-Agnostic Image Representation Learning for Digital Pathology}, 
      author={Saghir Alfasly and Abubakr Shafique and Peyman Nejat and Jibran Khan and Areej Alsaafin and Ghazal Alabtah and H. R. Tizhoosh},
      year={2023},
      eprint={2311.08359},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```""")

st.sidebar.markdown("\n\n")
st.sidebar.markdown("KIMIA Lab, Department of Artificial Intelligence and Informatics, \n Mayo Clinic, \n Rochester, MN, USA")




def visualize_attention_ViT(model, img, patch_size=16):
    """Return the per-head CLS-token attention maps of the model's last block.

    Args:
        model: object exposing ``get_last_selfattention(img)``. Its result is
            indexed here as ``[batch, head, query, key]`` with the CLS token at
            position 0 followed by one token per patch (DINO-style ViT
            assumed — confirm for other backbones).
        img: image batch shaped ``(B, C, H, W)``; only the first sample is
            visualized. ``H`` and ``W`` are expected to be multiples of
            ``patch_size`` so the patch tokens form a full grid.
        patch_size: side length, in pixels, of one ViT patch.

    Returns:
        A list with one ``numpy.ndarray`` per attention head, each shaped
        ``((H // patch_size) * patch_size, (W // patch_size) * patch_size)``:
        the CLS-to-patch attention upsampled to pixel resolution with
        nearest-neighbour interpolation.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Row-major patch grid: image height maps to rows, width to columns.
    h_featmap = img.shape[-2] // patch_size
    w_featmap = img.shape[-1] // patch_size
    attentions = model.get_last_selfattention(img.to(device))
    num_heads = attentions.shape[1]
    # Keep only the CLS query (index 0) attending to the patch tokens (1:).
    attentions = attentions[0, :, 0, 1:].reshape(num_heads, h_featmap, w_featmap)
    attentions = nn.functional.interpolate(
        attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest"
    )[0]
    # .cpu() is required before .numpy(): CUDA tensors cannot be converted directly.
    return list(attentions.detach().cpu().numpy())

def patch_aligned_size(height, width, patch_size=16):
    """Return ``(height, width)`` rounded down to multiples of ``patch_size``.

    Raises:
        ValueError: if ``patch_size`` is not positive, or if either side is
            smaller than one patch (the crop would contain no patches).
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    crop_h = height - height % patch_size
    crop_w = width - width % patch_size
    if crop_h == 0 or crop_w == 0:
        raise ValueError(
            f"Image of size {height}x{width} is smaller than one "
            f"{patch_size}x{patch_size} patch."
        )
    return crop_h, crop_w


# Define the function to generate activation maps
def generate_activation_maps(image, patch_size=16):
    """Compute per-head attention maps for an image.

    The image is center-cropped so both sides are multiples of ``patch_size``.
    It is then converted to a tensor, normalized with the fixed per-channel
    mean/std below, moved to the module-level ``device``, and passed through
    the module-level ``model`` without gradient tracking.

    Args:
        image: a PIL image (the UI passes an RGB image).
        patch_size: ViT patch size in pixels; also used to align the crop.

    Returns:
        The list of per-head attention maps from ``visualize_attention_ViT``.

    Raises:
        ValueError: if the image is smaller than one patch in either dimension.
    """
    # np.asarray(...).shape is (rows, cols, ...), i.e. (height, width).
    height, width = np.asarray(image).shape[:2]
    crop_h, crop_w = patch_aligned_size(height, width, patch_size)
    preprocess = transforms.Compose([
        # CenterCrop takes (height, width).
        transforms.CenterCrop((crop_h, crop_w)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        attention_list = visualize_attention_ViT(model=model, img=img, patch_size=patch_size)
    return attention_list

# Streamlit UI
st.title("PathDino - Compact ViT for Histopathology Image Analysis")
st.write("Upload a histology image to view the attention maps.")

uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])

if uploaded_image is not None:
    st.image(uploaded_image, caption="Uploaded Image", width=500)

    # Load the image and compute one attention map per head.
    pil_image = Image.open(uploaded_image).convert('RGB')
    attention_list = generate_activation_maps(pil_image)

    st.subheader("Attention Maps of the input image")
    # Lay the maps out two per row. Iterating over the returned list (rather
    # than a hard-coded head count) works for models with any number of heads.
    maps_per_row = 2
    for row_start in range(0, len(attention_list), maps_per_row):
        row_maps = attention_list[row_start:row_start + maps_per_row]
        for col, attention_map in zip(st.columns(maps_per_row), row_maps):
            fig, ax = plt.subplots()
            ax.imshow(attention_map)
            # Hide tick marks and labels but keep the image frame.
            ax.set_xticks([])
            ax.set_yticks([])
            col.pyplot(fig)
            # Release the figure; Streamlit reruns this script on every interaction.
            plt.close(fig)