narugo commited on
Commit
e0c2f10
1 Parent(s): 9f181f5

dev(narugo): try fix bug

Browse files
Files changed (2) hide show
  1. .gitignore +3 -0
  2. tagger/model.py +2 -1
.gitignore CHANGED
@@ -251,3 +251,6 @@ $RECYCLE.BIN/
251
 
252
  # but keep examples
253
  !*.example
 
 
 
 
251
 
252
  # but keep examples
253
  !*.example
254
+ /venv
255
+ .python-version
256
+ /.idea
tagger/model.py CHANGED
@@ -125,9 +125,10 @@ def render_heatmap(
125
  },
126
  partial_rows: bool = True,
127
  ) -> tuple[list[Heatmap], Image.Image]:
128
- hmap_dim = int(math.sqrt(pos_embed_dim))
129
 
130
  image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
 
131
  image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), hmap_dim, hmap_dim)
132
  image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
133
 
 
125
  },
126
  partial_rows: bool = True,
127
  ) -> tuple[list[Heatmap], Image.Image]:
128
+ # hmap_dim = int(math.sqrt(pos_embed_dim))
129
 
130
  image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
131
+ hmap_dim = int(math.sqrt(image_hmaps.mean(-1).shape[0] / len(image_labels)))
132
  image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), hmap_dim, hmap_dim)
133
  image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
134