Update region.py
Browse filesfix: return shape for decode.
region.py
CHANGED
|
@@ -71,19 +71,11 @@ def encode_size(size: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
|
| 71 |
return w.size_encoder(fourier_features(size, w.size_features))
|
| 72 |
|
| 73 |
|
| 74 |
-
# region.py
|
| 75 |
def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
| 76 |
-
|
| 77 |
-
|
|
|
|
| 78 |
|
| 79 |
-
Input (hidden_state): (..., C)
|
| 80 |
-
Output: (..., 2, bins) # keeps all leading dims intact
|
| 81 |
-
"""
|
| 82 |
-
x = mlp(hidden_state, w.size_decoder) # (..., size_out_dim)
|
| 83 |
-
last = x.shape[-1]
|
| 84 |
-
if last % 2 != 0:
|
| 85 |
-
raise RuntimeError(f"size_out_dim must be even, got {last}")
|
| 86 |
-
return x.view(*x.shape[:-1], 2, last // 2) # (..., 2, bins)
|
| 87 |
|
| 88 |
|
| 89 |
|
|
|
|
| 71 |
return w.size_encoder(fourier_features(size, w.size_features))
|
| 72 |
|
| 73 |
|
|
|
|
| 74 |
def decode_size(hidden_state: torch.Tensor, w: nn.Module) -> torch.Tensor:
|
| 75 |
+
# Original API expected by moondream.py: shape (2, C) when called on the last hidden state
|
| 76 |
+
x = mlp(hidden_state, w.size_decoder) # (..., 2*C)
|
| 77 |
+
return x.view(2, -1)
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
|