polejowska commited on
Commit
ab1dfe6
1 Parent(s): bc50c42

Update visualization.py

Browse files
Files changed (1) hide show
  1. visualization.py +0 -30
visualization.py CHANGED
@@ -21,17 +21,10 @@ def visualize_prediction(
21
  fig, ax = plt.subplots(figsize=(12, 12))
22
  ax.imshow(pil_img)
23
  if display_mask and mask is not None:
24
- # Convert the mask image to a numpy array
25
  mask_arr = np.asarray(mask)
26
-
27
- # Create a new mask with white objects and black background
28
  new_mask = np.zeros_like(mask_arr)
29
  new_mask[mask_arr > 0] = 255
30
-
31
- # Convert the numpy array back to a PIL Image
32
  new_mask = Image.fromarray(new_mask)
33
-
34
- # Display the new mask as a semi-transparent overlay
35
  ax.imshow(new_mask, alpha=0.5, cmap='viridis')
36
 
37
  colors = COLORS * 100
@@ -62,26 +55,20 @@ def visualize_prediction(
62
 
63
 
64
  def visualize_attention_map(pil_img, attention_map):
65
- # Get the attention map for the last layer
66
  attention_map = attention_map[-1].detach().cpu()
67
 
68
- # Get the number of heads
69
  n_heads = attention_map.shape[1]
70
 
71
- # Calculate the average attention weight for each head
72
  avg_attention_weight = torch.mean(attention_map, dim=1).squeeze()
73
 
74
- # Resize the attention map
75
  resized_attention_weight = F.interpolate(
76
  avg_attention_weight.unsqueeze(0).unsqueeze(0),
77
  size=pil_img.size[::-1],
78
  mode="bicubic",
79
  ).squeeze().numpy()
80
 
81
- # Create a grid of subplots
82
  fig, axes = plt.subplots(nrows=1, ncols=n_heads, figsize=(n_heads*4, 4))
83
 
84
- # Loop through the subplots and plot the attention for each head
85
  for i, ax in enumerate(axes.flat):
86
  ax.imshow(pil_img)
87
  ax.imshow(attention_map[0,i,:,:].squeeze(), alpha=0.7, cmap="viridis")
@@ -91,20 +78,3 @@ def visualize_attention_map(pil_img, attention_map):
91
  plt.tight_layout()
92
 
93
  return fig2img(fig)
94
- # attention_map = attention_map[-1].detach().cpu()
95
- # avg_attention_weight = torch.mean(attention_map, dim=1).squeeze()
96
- # avg_attention_weight_resized = (
97
- # F.interpolate(
98
- # avg_attention_weight.unsqueeze(0).unsqueeze(0),
99
- # size=pil_img.size[::-1],
100
- # mode="bicubic",
101
- # )
102
- # .squeeze()
103
- # .numpy()
104
- # )
105
-
106
- # plt.imshow(pil_img)
107
- # plt.imshow(avg_attention_weight_resized, alpha=0.7, cmap="viridis")
108
- # plt.axis("off")
109
- # fig = plt.gcf()
110
- # return fig2img(fig)
 
21
  fig, ax = plt.subplots(figsize=(12, 12))
22
  ax.imshow(pil_img)
23
  if display_mask and mask is not None:
 
24
  mask_arr = np.asarray(mask)
 
 
25
  new_mask = np.zeros_like(mask_arr)
26
  new_mask[mask_arr > 0] = 255
 
 
27
  new_mask = Image.fromarray(new_mask)
 
 
28
  ax.imshow(new_mask, alpha=0.5, cmap='viridis')
29
 
30
  colors = COLORS * 100
 
55
 
56
 
57
  def visualize_attention_map(pil_img, attention_map):
 
58
  attention_map = attention_map[-1].detach().cpu()
59
 
 
60
  n_heads = attention_map.shape[1]
61
 
 
62
  avg_attention_weight = torch.mean(attention_map, dim=1).squeeze()
63
 
 
64
  resized_attention_weight = F.interpolate(
65
  avg_attention_weight.unsqueeze(0).unsqueeze(0),
66
  size=pil_img.size[::-1],
67
  mode="bicubic",
68
  ).squeeze().numpy()
69
 
 
70
  fig, axes = plt.subplots(nrows=1, ncols=n_heads, figsize=(n_heads*4, 4))
71
 
 
72
  for i, ax in enumerate(axes.flat):
73
  ax.imshow(pil_img)
74
  ax.imshow(attention_map[0,i,:,:].squeeze(), alpha=0.7, cmap="viridis")
 
78
  plt.tight_layout()
79
 
80
  return fig2img(fig)