dev(narugo): try fix bug
Browse files- .gitignore +3 -0
- tagger/model.py +2 -1
.gitignore
CHANGED
@@ -251,3 +251,6 @@ $RECYCLE.BIN/
|
|
251 |
|
252 |
# but keep examples
|
253 |
!*.example
|
|
|
|
|
|
|
|
251 |
|
252 |
# but keep examples
|
253 |
!*.example
|
254 |
+
/venv
|
255 |
+
.python-version
|
256 |
+
/.idea
|
tagger/model.py
CHANGED
@@ -125,9 +125,10 @@ def render_heatmap(
|
|
125 |
},
|
126 |
partial_rows: bool = True,
|
127 |
) -> tuple[list[Heatmap], Image.Image]:
|
128 |
-
hmap_dim = int(math.sqrt(pos_embed_dim))
|
129 |
|
130 |
image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
|
|
|
131 |
image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), hmap_dim, hmap_dim)
|
132 |
image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
|
133 |
|
|
|
125 |
},
|
126 |
partial_rows: bool = True,
|
127 |
) -> tuple[list[Heatmap], Image.Image]:
|
128 |
+
# hmap_dim = int(math.sqrt(pos_embed_dim))
|
129 |
|
130 |
image_hmaps = gradients.mean(2, keepdim=True).mul(image_feats.unsqueeze(0)).squeeze()
|
131 |
+
hmap_dim = int(math.sqrt(image_hmaps.mean(-1).shape[0] / len(image_labels)))
|
132 |
image_hmaps = image_hmaps.mean(-1).reshape(len(image_labels), hmap_dim, hmap_dim)
|
133 |
image_hmaps = image_hmaps.max(torch.zeros_like(image_hmaps))
|
134 |
|