Spaces:
Runtime error
Runtime error
polejowska
committed on
Commit
•
ab1dfe6
1
Parent(s):
bc50c42
Update visualization.py
Browse files- visualization.py +0 -30
visualization.py
CHANGED
@@ -21,17 +21,10 @@ def visualize_prediction(
|
|
21 |
fig, ax = plt.subplots(figsize=(12, 12))
|
22 |
ax.imshow(pil_img)
|
23 |
if display_mask and mask is not None:
|
24 |
-
# Convert the mask image to a numpy array
|
25 |
mask_arr = np.asarray(mask)
|
26 |
-
|
27 |
-
# Create a new mask with white objects and black background
|
28 |
new_mask = np.zeros_like(mask_arr)
|
29 |
new_mask[mask_arr > 0] = 255
|
30 |
-
|
31 |
-
# Convert the numpy array back to a PIL Image
|
32 |
new_mask = Image.fromarray(new_mask)
|
33 |
-
|
34 |
-
# Display the new mask as a semi-transparent overlay
|
35 |
ax.imshow(new_mask, alpha=0.5, cmap='viridis')
|
36 |
|
37 |
colors = COLORS * 100
|
@@ -62,26 +55,20 @@ def visualize_prediction(
|
|
62 |
|
63 |
|
64 |
def visualize_attention_map(pil_img, attention_map):
|
65 |
-
# Get the attention map for the last layer
|
66 |
attention_map = attention_map[-1].detach().cpu()
|
67 |
|
68 |
-
# Get the number of heads
|
69 |
n_heads = attention_map.shape[1]
|
70 |
|
71 |
-
# Calculate the average attention weight for each head
|
72 |
avg_attention_weight = torch.mean(attention_map, dim=1).squeeze()
|
73 |
|
74 |
-
# Resize the attention map
|
75 |
resized_attention_weight = F.interpolate(
|
76 |
avg_attention_weight.unsqueeze(0).unsqueeze(0),
|
77 |
size=pil_img.size[::-1],
|
78 |
mode="bicubic",
|
79 |
).squeeze().numpy()
|
80 |
|
81 |
-
# Create a grid of subplots
|
82 |
fig, axes = plt.subplots(nrows=1, ncols=n_heads, figsize=(n_heads*4, 4))
|
83 |
|
84 |
-
# Loop through the subplots and plot the attention for each head
|
85 |
for i, ax in enumerate(axes.flat):
|
86 |
ax.imshow(pil_img)
|
87 |
ax.imshow(attention_map[0,i,:,:].squeeze(), alpha=0.7, cmap="viridis")
|
@@ -91,20 +78,3 @@ def visualize_attention_map(pil_img, attention_map):
|
|
91 |
plt.tight_layout()
|
92 |
|
93 |
return fig2img(fig)
|
94 |
-
# attention_map = attention_map[-1].detach().cpu()
|
95 |
-
# avg_attention_weight = torch.mean(attention_map, dim=1).squeeze()
|
96 |
-
# avg_attention_weight_resized = (
|
97 |
-
# F.interpolate(
|
98 |
-
# avg_attention_weight.unsqueeze(0).unsqueeze(0),
|
99 |
-
# size=pil_img.size[::-1],
|
100 |
-
# mode="bicubic",
|
101 |
-
# )
|
102 |
-
# .squeeze()
|
103 |
-
# .numpy()
|
104 |
-
# )
|
105 |
-
|
106 |
-
# plt.imshow(pil_img)
|
107 |
-
# plt.imshow(avg_attention_weight_resized, alpha=0.7, cmap="viridis")
|
108 |
-
# plt.axis("off")
|
109 |
-
# fig = plt.gcf()
|
110 |
-
# return fig2img(fig)
|
|
|
21 |
fig, ax = plt.subplots(figsize=(12, 12))
|
22 |
ax.imshow(pil_img)
|
23 |
if display_mask and mask is not None:
|
|
|
24 |
mask_arr = np.asarray(mask)
|
|
|
|
|
25 |
new_mask = np.zeros_like(mask_arr)
|
26 |
new_mask[mask_arr > 0] = 255
|
|
|
|
|
27 |
new_mask = Image.fromarray(new_mask)
|
|
|
|
|
28 |
ax.imshow(new_mask, alpha=0.5, cmap='viridis')
|
29 |
|
30 |
colors = COLORS * 100
|
|
|
55 |
|
56 |
|
57 |
def visualize_attention_map(pil_img, attention_map):
|
|
|
58 |
attention_map = attention_map[-1].detach().cpu()
|
59 |
|
|
|
60 |
n_heads = attention_map.shape[1]
|
61 |
|
|
|
62 |
avg_attention_weight = torch.mean(attention_map, dim=1).squeeze()
|
63 |
|
|
|
64 |
resized_attention_weight = F.interpolate(
|
65 |
avg_attention_weight.unsqueeze(0).unsqueeze(0),
|
66 |
size=pil_img.size[::-1],
|
67 |
mode="bicubic",
|
68 |
).squeeze().numpy()
|
69 |
|
|
|
70 |
fig, axes = plt.subplots(nrows=1, ncols=n_heads, figsize=(n_heads*4, 4))
|
71 |
|
|
|
72 |
for i, ax in enumerate(axes.flat):
|
73 |
ax.imshow(pil_img)
|
74 |
ax.imshow(attention_map[0,i,:,:].squeeze(), alpha=0.7, cmap="viridis")
|
|
|
78 |
plt.tight_layout()
|
79 |
|
80 |
return fig2img(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|