Saghir committed
Commit be2c585
1 Parent(s): a50d65f

Create app.py

Files changed (1)
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
+ import os
+
+ import streamlit as st
+ import torch
+ import torch.nn as nn
+ import matplotlib.pyplot as plt
+ from torchvision import transforms
+ from PIL import Image
+
+ from PathDino import get_pathDino_model
+
+ # Allow duplicate OpenMP runtimes (works around the common libiomp clash
+ # between PyTorch and other native libraries on some systems).
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Load the PathDino model and its matching image transforms
+ model, image_transforms = get_pathDino_model("PathDino512.pth")
+
+
+ st.sidebar.markdown("### PathDino")
+ st.sidebar.markdown(
+     "PathDino is a lightweight histology transformer consisting of just five small vision transformer blocks. "
+     "It is a customized ViT architecture, finely tuned to the nuances of histological images, that not only "
+     "exhibits superior performance but is also less susceptible to overfitting, a common challenge in histology "
+     "image analysis.\n\n"
+ )
+
+ default_image_url_compare = "images/HistRotate.png"
+ st.sidebar.image(default_image_url_compare, caption='A 360-degree rotation augmentation for training models on histopathology images. Unlike natural images, where rotation may change the context of the visual data, rotating a histopathology patch preserves its context and improves the learning process, yielding more reliable embeddings.', width=500)
+
+ default_image_url_compare = "images/FigPathDino_parameters_FLOPs_compare.png"
+ st.sidebar.image(default_image_url_compare, caption='PathDino vs. its counterparts: number of parameters (millions) versus patch-level retrieval with macro-averaged F-score of majority vote (MV@5) on the CAMELYON16 dataset. Bubble size represents FLOPs.', width=500)
+
+ default_image_url_compare = "images/ActivationMap.png"
+ st.sidebar.image(default_image_url_compare, caption='Attention visualization. When visualizing attention patterns, our PathDino transformer outperforms HIPT-small and DinoSSLPath despite being trained on a smaller dataset of 6 million TCGA patches; DinoSSLPath and HIPT were trained on much larger datasets of 19 million and 104 million TCGA patches, respectively.', width=500)
+
+
+
+ def visualize_attention_ViT(model, img, patch_size=16):
+     """Return one upsampled self-attention map per head for the CLS token."""
+     attention_list = []
+     w_featmap = img.shape[-2] // patch_size
+     h_featmap = img.shape[-1] // patch_size
+     attentions = model.get_last_selfattention(img.to(device))
+     nh = attentions.shape[1]  # number of heads
+     # Keep only the CLS-token attention over the image patches.
+     attentions = attentions[0, :, 0, 1:].reshape(nh, -1)
+     attentions = attentions.reshape(nh, w_featmap, h_featmap)
+     # Upsample each head's patch-grid map back to pixel resolution.
+     attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].detach().cpu().numpy()
+     for j in range(nh):
+         attention_list.append(attentions[j])
+     return attention_list
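+
+ # Shape sketch (illustrative, not executed by the app): for a 1x3x512x512
+ # input with patch_size=16, the patch grid is 32x32 and each map is upsampled
+ # back to 512x512, one per attention head, e.g.:
+ #   maps = visualize_attention_ViT(model, torch.randn(1, 3, 512, 512))
+ #   assert maps[0].shape == (512, 512)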
+
+ # Generate per-head attention maps for an input PIL image.
+ def generate_activation_maps(image):
+     preprocess = transforms.Compose([
+         transforms.Resize((512, 512)),
+         transforms.CenterCrop(512),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
+     ])
+     image_tensor = preprocess(image)
+     img = image_tensor.unsqueeze(0).to(device)  # add a batch dimension
+     # Generate activation maps without tracking gradients.
+     with torch.no_grad():
+         attention_list = visualize_attention_ViT(model=model, img=img, patch_size=16)
+     return attention_list
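+
+ # Usage sketch (illustrative; "patch.png" is a placeholder path):
+ #   attention_list = generate_activation_maps(Image.open("patch.png").convert('RGB'))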
+
+ # Streamlit UI
+ st.title("PathDino - Compact ViT for Histopathology Image Analysis")
+ st.write("Upload a histology image to view the activation maps.")
+
+ uploaded_image = st.file_uploader("Upload an image", type=["jpg", "png", "jpeg"])
+
+ if uploaded_image is not None:
+     columns = st.columns(3)
+     columns[1].image(uploaded_image, caption="Uploaded Image", width=300)
+
+     # Load the image and apply preprocessing
+     uploaded_image = Image.open(uploaded_image).convert('RGB')
+     attention_list = generate_activation_maps(uploaded_image)
+     st.subheader("Attention Maps of the input image")
+     # Lay the per-head maps out across two rows of columns.
+     columns = st.columns(len(attention_list) // 2)
+     columns2 = st.columns(len(attention_list) // 2)
+     # First row of attention heads.
+     for index, col in enumerate(columns):
+         plt.figure()
+         plt.axis("off")  # hide axis ticks and labels
+         plt.imshow(attention_list[index])
+         col.pyplot(plt)
+         plt.close()
+
+     # Second row of attention heads.
+     for index, col in enumerate(columns2):
+         index = index + len(attention_list) // 2
+         plt.figure()
+         plt.axis("off")  # hide axis ticks and labels
+         plt.imshow(attention_list[index])
+         col.pyplot(plt)
+         plt.close()
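+
+ # To run locally (assuming Streamlit is installed and PathDino512.pth sits
+ # alongside this file): streamlit run app.py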