lukemelas committed
Commit cfe4337
1 Parent(s): 95dc30b

Update app

Files changed (1):
  1. app.py +174 -66
app.py CHANGED
@@ -1,91 +1,199 @@
- import os, os.path
- from os.path import splitext
- import numpy as np
  import sys
  import matplotlib.pyplot as plt
  import torch
  import torchvision
- import wget

- destination_folder = "output"
- destination_for_weights = "weights"

- if os.path.exists(destination_for_weights):
-     print("The weights are at", destination_for_weights)
- else:
-     print("Creating folder at ", destination_for_weights, " to store weights")
-     os.mkdir(destination_for_weights)
-
- segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'

- if not os.path.exists(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))):
-     print("Downloading Segmentation Weights, ", segmentationWeightsURL," to ",os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
-     filename = wget.download(segmentationWeightsURL, out = destination_for_weights)
- else:
-     print("Segmentation Weights already present")

- torch.cuda.empty_cache()

- def collate_fn(x):
-     x, f = zip(*x)
-     i = list(map(lambda t: t.shape[1], x))
-     x = torch.as_tensor(np.swapaxes(np.concatenate(x, 1), 0, 1))
-     return x, f, i

- model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
- model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)

- print("loading weights from ", os.path.join(destination_for_weights, "deeplabv3_resnet50_random"))

  if torch.cuda.is_available():
-     print("cuda is available, original weights")
      device = torch.device("cuda")
-     model = torch.nn.DataParallel(model)
      model.to(device)
-     checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
-     model.load_state_dict(checkpoint['state_dict'])
  else:
-     print("cuda is not available, cpu weights")
      device = torch.device("cpu")
-     checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)), map_location = "cpu")
-     state_dict_cpu = {k[7:]: v for (k, v) in checkpoint['state_dict'].items()}
-     model.load_state_dict(state_dict_cpu)

- model.eval()

- def segment(inp):
-     x = inp.transpose([2, 0, 1])  # channels-first
-     x = np.expand_dims(x, axis=0)  # adding a batch dimension
-
-     mean = x.mean(axis=(0, 2, 3))
-     std = x.std(axis=(0, 2, 3))
-     x = x - mean.reshape(1, 3, 1, 1)
-     x = x / std.reshape(1, 3, 1, 1)
-
-     with torch.no_grad():
-         x = torch.from_numpy(x).type('torch.FloatTensor').to(device)
-         output = model(x)
-
-     y = output['out'].numpy()
-     y = y.squeeze()
-
-     out = y>0
-
-     mask = inp.copy()
-     mask[out] = np.array([0, 0, 255])

-     return mask

- import gradio as gr

- i = gr.inputs.Image(shape=(112, 112))
- o = gr.outputs.Image()

- examples = [["img1.jpg"], ["img2.jpg"]]
- title = None #"Left Ventricle Segmentation"
- description = "This semantic segmentation model identifies the left ventricle in echocardiogram images."
- # videos. Accurate evaluation of the motion and size of the left ventricle is crucial for the assessment of cardiac function and ejection fraction. In this interface, the user inputs apical-4-chamber images from echocardiography videos and the model will output a prediction of the localization of the left ventricle in blue. This model was trained on the publicly released EchoNet-Dynamic dataset of 10k echocardiogram videos with 20k expert annotations of the left ventricle and published as part of ‘Video-based AI for beat-to-beat assessment of cardiac function’ by Ouyang et al. in Nature, 2020."
  thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
- gr.Interface(segment, i, o, examples=examples, allow_flagging=False, analytics_enabled=False,
-              title=title, description=description, thumbnail=thumbnail).launch()

+ import os
+ import os.path
  import sys
+ from os.path import splitext
+
+ import gradio as gr
  import matplotlib.pyplot as plt
+ import numpy as np
+ import scipy.sparse
  import torch
+ import torch.nn.functional as F
  import torchvision
+ import torchvision.transforms.functional as TF
+ from gradio.inputs import Image as GradioInputImage
+ from gradio.outputs import Image as GradioOutputImage
+ from PIL import Image
+ from scipy.sparse.linalg import eigsh
+ from skimage.color import label2rgb
+ from torch.utils.hooks import RemovableHandle
+ from torchvision import transforms
+ from torchvision.utils import make_grid

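+ # get_model returns the backbone together with its eval transform, patch size, and number of
+ # attention heads; the 'mocov3_*' variants reuse the DINO ViT definition and load MoCo v3 weights.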
+ def get_model(name: str):
+     if 'dino' in name:
+         model = torch.hub.load('facebookresearch/dino:main', name)
+         model.fc = torch.nn.Identity()
+         val_transform = get_transform(name)
+         patch_size = model.patch_embed.patch_size
+         num_heads = model.blocks[0].attn.num_heads
+     elif name in ['mocov3_vits16', 'mocov3_vitb16']:
+         model = torch.hub.load('facebookresearch/dino:main', name.replace('mocov3', 'dino'))
+         checkpoint_file, size_char = {
+             'mocov3_vits16': ('vit-s-300ep-timm-format.pth', 's'),
+             'mocov3_vitb16': ('vit-b-300ep-timm-format.pth', 'b'),
+         }[name]
+         url = f'https://dl.fbaipublicfiles.com/moco-v3/vit-{size_char}-300ep/vit-{size_char}-300ep.pth.tar'
+         checkpoint = torch.hub.load_state_dict_from_url(url)
+         model.load_state_dict(checkpoint['model'])
+         model.fc = torch.nn.Identity()
+         val_transform = get_transform(name)
+         patch_size = model.patch_embed.patch_size
+         num_heads = model.blocks[0].attn.num_heads
+     else:
+         raise ValueError(f'Unsupported model: {name}')
+     model = model.eval()
+     return model, val_transform, patch_size, num_heads


+ def get_transform(name: str):
+     if any(x in name for x in ('dino', 'mocov3', 'convnext')):
+         normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
+         transform = transforms.Compose([transforms.ToTensor(), normalize])
+     else:
+         raise NotImplementedError()
+     return transform

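+ # The helper below builds the degree matrix D (row sums of the affinity matrix W).
+ # D - W, used in segment(), is the graph Laplacian whose low eigenvectors give the soft segments.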
+ def get_diagonal(W: scipy.sparse.csr_matrix, threshold: float = 1e-12):
+     D = W.dot(np.ones(W.shape[1], W.dtype))
+     D[D < threshold] = 1.0  # Prevent division by zero.
+     D = scipy.sparse.diags(D)
+     return D


+ # Parameters
+ model_name = 'dino_vitb16'  # TODO: Figure out how to make this user-editable
+ K = 5
+
+ # Load model
+ model, val_transform, patch_size, num_heads = get_model(model_name)
+
+
+ # GPU
  if torch.cuda.is_available():
+     print("CUDA is available, using GPU.")
      device = torch.device("cuda")
      model.to(device)
  else:
+     print("CUDA is not available, using CPU.")
      device = torch.device("cpu")


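+ # segment() runs one image through the pipeline: extract per-patch features, build the
+ # affinity graph, take Laplacian eigenvectors, and return them arranged in an image grid.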
+ @torch.no_grad()
+ def segment(inp: Image):
+     # NOTE: The image is already resized to the desired size.
+
+     # Preprocess image
+     images: torch.Tensor = val_transform(inp)
+     images = images.unsqueeze(0).to(device)
+
+     # Add hook
+     which_block = -1
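+     # Register a forward hook on the chosen attention block's qkv projection; the hook
+     # stores its output so per-patch key features can be read out after the forward pass.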
+     if 'dino' in model_name or 'mocov3' in model_name:
+         feat_out = {}
+         def hook_fn_forward_qkv(module, input, output):
+             feat_out["qkv"] = output
+         handle: RemovableHandle = model._modules["blocks"][which_block]._modules["attn"]._modules["qkv"].register_forward_hook(
+             hook_fn_forward_qkv
+         )
+     else:
+         raise ValueError(model_name)
+
+     # Reshape image
+     P = patch_size
+     B, C, H, W = images.shape
+     H_patch, W_patch = H // P, W // P
+     H_pad, W_pad = H_patch * P, W_patch * P
+     T = H_patch * W_patch + 1  # number of tokens, add 1 for [CLS]
+
+     # Crop image to be a multiple of the patch size
+     images = images[:, :, :H_pad, :W_pad]
+
+     # Extract features
+     if 'dino' in model_name or 'mocov3' in model_name:
+         model.get_intermediate_layers(images)[0].squeeze(0)  # forward pass; the output is unused, the hook captures qkv
+         output_qkv = feat_out["qkv"].reshape(B, T, 3, num_heads, -1).permute(2, 0, 3, 1, 4)  # (3, B, num_heads, T, head_dim)
+         feats = output_qkv[1].transpose(1, 2).reshape(B, T, -1)[:, 1:, :].squeeze(0)  # keys, without the [CLS] token
+     else:
+         raise ValueError(model_name)
+
+     # Remove hook from the model
+     handle.remove()
+
+     # Normalize features
+     normalize = True
+     if normalize:
+         feats = F.normalize(feats, p=2, dim=-1)
+
+     # Compute affinity matrix
+     W_feat = (feats @ feats.T)
+
+     # Feature affinities
+     threshold_at_zero = True
+     if threshold_at_zero:
+         W_feat = (W_feat * (W_feat > 0))
+     W_feat = W_feat / W_feat.max()  # NOTE: If features are normalized, this naturally does nothing
+     W_feat = W_feat.cpu().numpy()
+
+     # NOTE: Here is where we would add the color information. For simplicity, we will not add it here.
+     # W_comb = W_feat + W_color * image_color_lambda  # combination
+     # D_comb = np.array(get_diagonal(W_comb).todense())  # is dense or sparse faster? not sure, should check
+
+     # Diagonal
+     W_comb = W_feat
+     D_comb = np.array(get_diagonal(W_comb).todense())  # is dense or sparse faster? not sure, should check
+
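+     # Smallest eigenvectors of the generalized problem (D - W) v = lambda * D v; shift-invert
+     # (sigma=0, which='LM') is the fast path, with a plain which='SM' solve as the fallback.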
+     # Compute eigenvectors
+     try:
+         eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=(K + 1), sigma=0, which='LM', M=D_comb)
+     except Exception:
+         eigenvalues, eigenvectors = eigsh(D_comb - W_comb, k=(K + 1), which='SM', M=D_comb)
+     eigenvalues = torch.from_numpy(eigenvalues)
+     eigenvectors = torch.from_numpy(eigenvectors.T).float()
+
+     # Resolve sign ambiguity
+     for k in range(eigenvectors.shape[0]):
+         if 0.5 < torch.mean((eigenvectors[k] > 0).float()).item() < 1.0:  # reverse segment
+             eigenvectors[k] = 0 - eigenvectors[k]
+
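+     # Each eigenvector lives on the patch grid (H_patch x W_patch); reshape it there and
+     # upsample to the input resolution for display.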
+     # Arrange eigenvectors into grid
+     output_image_grid = []
+     for i in range(1, K):
+         eigenvector = eigenvectors[i].reshape(1, 1, H_patch, W_patch)
+         eigenvector = F.interpolate(eigenvector, size=(H, W), mode='nearest')  # slightly off, but for visualizations this is okay
+         # plt.imsave('./tmp.png', eigenvector.squeeze().numpy())  # save to a temporary location
+         # eigenvector = Image.open('./tmp.png').convert('RGB')  # load back from our temporary location
+         output_image_grid.append(eigenvector.squeeze(0))  # (1, H, W): make_grid expects a list of (C, H, W) images
+     img_tensor_grid = make_grid(output_image_grid, nrow=8, pad_value=1)
+
+     # Postprocess for Gradio: convert to a channels-last numpy array
+     img_tensor_grid = img_tensor_grid.permute(1, 2, 0).numpy()
+
+     return img_tensor_grid
+
+ # Placeholders
+ input_placeholders = GradioInputImage(shape=(256, 256), source="upload", tool="editor", type="pil")
+ output_placeholders = GradioOutputImage(type="numpy", label="Eigenvectors")
+ # alternatively: [GradioOutputImage(type="numpy", label=f"Eigenvector {i}") for i in range(K)]
+
+ # Metadata
+ examples = [["images/img1.jpg"], ["images/img2.jpg"]]
+ title = "Deep Spectral Segmentation"
+ description = "Deep spectral segmentation..."
  thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
+
+ # Gradio
+ gr.Interface(
+     segment,
+     input_placeholders,
+     output_placeholders,
+     examples=examples,
+     allow_flagging=False,
+     analytics_enabled=False,
+     title=title,
+     description=description,
+     thumbnail=thumbnail
+ ).launch()