Update moondream.py

fix: per-row generation for granularity.

moondream.py  CHANGED  (+56 −55); new side shown below, added lines marked with "+".
@@ -951,42 +951,50 @@ class MoondreamModel(nn.Module):
        temperature: float = 0.0,
        top_p: float = 0.0,
    ):
+        """
+        Batch prefill for multiple detection labels.
+
+        - Right-pads each row with its *last* embedding so the true last token for
+          each row is still at index (len-1). We then take that per-row index.
+        - Advances KV to a common end position (pos + T) for all rows.
+        """
        tpl = self.config.tokenizer.templates["detect"]
        if tpl is None:
            raise NotImplementedError("Model does not support object detection.")

+        # Tokenize rows (variable lengths Li)
        rows_ids, lens = [], []
        for lab in labels:
            ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
            t = torch.tensor(ids, device=self.device, dtype=torch.long)
            rows_ids.append(t)
+            lens.append(int(t.numel()))

        B = len(rows_ids)
        T = max(lens)

+        # Embed, then RIGHT-pad by repeating the last real token embedding
+        embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids]  # (Li, C)
        padded = []
        for e, L in zip(embs, lens):
            pad = T - L
            if pad > 0:
+                e = torch.cat([e, e[-1:].repeat(pad, 1)], dim=0)  # (T, C)
            padded.append(e)
        prompt_emb = torch.stack(padded, dim=0)  # (B, T, C)
        torch._dynamo.mark_dynamic(prompt_emb, 1)

+        # Shared mask over the image prefix; broadcast to B
        base = self.attn_mask[:, :, pos : pos + T, :]  # (1,1,T,K)
+        attn_mask = base.expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
+        pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)

+        # Prefill
+        hidden_BTC = self._prefill(prompt_emb, attn_mask, pos_ids, lora)  # (B,T,C)
+        logits_BTV = lm_head(hidden_BTC, self.text)  # (B,T,V)

+        # For each row, pick its *true* last token (Li-1), not a padded index
+        last_idx = torch.tensor([L - 1 for L in lens], device=self.device, dtype=torch.long)  # (B,)

        last_hidden = hidden_BTC[torch.arange(B, device=self.device), last_idx][:, None, :]  # (B,1,C)
        last_logits = logits_BTV[torch.arange(B, device=self.device), last_idx]  # (B,V)
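Note on the hunk above: because padding repeats each row's last real embedding, index len-1 of every row still holds that row's true last token, which is what the last_idx gather relies on. A minimal sketch of just that idea, assuming only torch (pad_and_gather_last is a hypothetical helper, not part of moondream.py):

import torch

def pad_and_gather_last(embs):  # embs: list of (Li, C) tensors
    lens = [e.size(0) for e in embs]
    T = max(lens)
    # Right-pad each row by repeating its last embedding up to length T
    padded = [torch.cat([e, e[-1:].repeat(T - e.size(0), 1)], dim=0) if e.size(0) < T else e
              for e in embs]
    batch = torch.stack(padded, dim=0)                   # (B, T, C)
    last_idx = torch.tensor([L - 1 for L in lens])       # (B,)
    last = batch[torch.arange(len(embs)), last_idx]      # (B, C): true last token per row
    return batch, last

embs = [torch.randn(3, 4), torch.randn(5, 4)]
batch, last = pad_and_gather_last(embs)
assert torch.allclose(last[0], embs[0][-1]) and torch.allclose(last[1], embs[1][-1])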
@@ -998,17 +1006,15 @@ class MoondreamModel(nn.Module):
            probs = self._apply_top_p(probs, top_p)
            next_token = torch.multinomial(probs, num_samples=1)  # (B,1)

+        # We advanced KV for T steps for everyone; decoding starts after that slot.
        pos_end = int(pos + T)
        return last_hidden, next_token, pos_end

    def _generate_points_batched(
        self,
+        hidden,      # (B,1,C) last token hidden per row
+        next_token,  # (B,1)
+        pos,         # int: first free KV slot (after prefill)
        include_size: bool = True,
        max_objects: int = 50,
        lora=None,
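Note: _apply_top_p is called above but its body is not part of this diff. For reference, a generic nucleus (top-p) filter over (B, V) probabilities typically looks like the sketch below; this is an assumption about its behavior, not the repo's implementation:

import torch

def apply_top_p(probs: torch.Tensor, top_p: float) -> torch.Tensor:
    if top_p <= 0.0 or top_p >= 1.0:
        return probs
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Drop tokens once cumulative mass (excluding the token itself) exceeds top_p,
    # so at least the most likely token always survives.
    remove = cumulative - sorted_probs > top_p
    sorted_probs = sorted_probs.masked_fill(remove, 0.0)
    filtered = torch.zeros_like(probs).scatter(-1, sorted_idx, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)

probs = torch.softmax(torch.randn(2, 10), dim=-1)
next_token = torch.multinomial(apply_top_p(probs, 0.9), num_samples=1)  # (B, 1)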
@@ -1020,15 +1026,14 @@ class MoondreamModel(nn.Module):
        eos_id = self.config.tokenizer.eos_id
        max_ctx = self.config.text.max_context

+        # Per-row decoding mask & pos pointer
+        attn = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)  # (B,1,1,K)
+        if pos > 0:
+            attn[:, :, :, :pos] = True
+        pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)

        def _argmax01(logits: torch.Tensor) -> torch.Tensor:
+            # returns normalized [0,1] bin position
            if logits.dim() == 3:
                logits = logits.squeeze(1)  # (B, bins)
            if use_soft_argmax:
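Note: the soft-argmax branch of _argmax01 is only partly visible in this diff. A standard soft-argmax over coordinate bins, normalized to [0, 1], is sketched below for reference; the exact bin normalization (bins vs. bins - 1) is an assumption:

import torch

def soft_argmax01(logits: torch.Tensor) -> torch.Tensor:
    # Expected bin index under softmax, mapped to a normalized [0, 1] coordinate.
    if logits.dim() == 3:                        # (B, 1, bins) -> (B, bins)
        logits = logits.squeeze(1)
    bins = torch.arange(logits.size(-1), dtype=torch.float32, device=logits.device)
    probs = torch.softmax(logits.float(), dim=-1)
    expected_bin = (probs * bins).sum(dim=-1)            # (B,)
    return expected_bin / max(logits.size(-1) - 1, 1)    # assumed normalization

coords = soft_argmax01(torch.randn(4, 1, 1024))          # 4 rows, 1024 bins -> values in [0, 1]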
@@ -1043,31 +1048,30 @@ class MoondreamModel(nn.Module):

        with torch.inference_mode():
            while alive.any() and (counts < max_objects).any():
+                idx = alive.nonzero(as_tuple=False).squeeze(1)

+                # ---- x ----
+                x_logits = decode_coordinate(hidden, self.region)
+                x_center = _argmax01(x_logits)
+                x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)

+                attn[idx, 0, 0, pos_ids[idx, 0]] = True
+                logits, hidden = self._decode_one_tok(x_emb, attn, pos_ids, lora)
+                pos_ids[idx, 0] += 1

+                # ---- y ----
                y_logits = decode_coordinate(hidden, self.region)
+                y_center = _argmax01(y_logits)
                y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)

+                attn[idx, 0, 0, pos_ids[idx, 0]] = True
+                logits, hidden = self._decode_one_tok(y_emb, attn, pos_ids, lora)
+                pos_ids[idx, 0] += 1

                if include_size:
+                    # ---- (w,h) ----
+                    size_ret = decode_size(hidden, self.region)  # (...,2,bins)
+                    w_logits, h_logits = self._norm_size_logits(size_ret, B)

                    if use_soft_argmax:
                        bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
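Note on the decode loop above: each of the x, y and size steps opens exactly one new attention slot for the rows that are still alive and then advances that row's position pointer, so rows that finish early stop consuming KV slots. A standalone sketch of that bookkeeping, with made-up sizes and no model involved:

import torch

B, max_ctx, pos = 3, 16, 5                          # 3 rows, prefill ended at slot 5
attn = torch.zeros(B, 1, 1, max_ctx, dtype=torch.bool)
attn[:, :, :, :pos] = True                          # prefix (image + prompt) is visible
pos_ids = torch.full((B, 1), pos, dtype=torch.long)
alive = torch.tensor([True, False, True])

for _ in range(2):                                  # two decode steps (e.g. x then y)
    idx = alive.nonzero(as_tuple=False).squeeze(1)
    attn[idx, 0, 0, pos_ids[idx, 0]] = True         # make the slot being written visible
    pos_ids[idx, 0] += 1                            # next write goes one slot further

print(attn[:, 0, 0].sum(dim=-1))                    # tensor([7, 5, 7]): only alive rows advanced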
@@ -1083,8 +1087,7 @@ class MoondreamModel(nn.Module):

                    size_emb = encode_size(torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), self.region).unsqueeze(1)

+                    for i in idx.tolist():
                        xl = (x_center[i] - w[i] / 2).item()
                        xr = (x_center[i] + w[i] / 2).item()
                        yt = (y_center[i] - h[i] / 2).item()
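Note: the center/size-to-box arithmetic starts here (xl, xr, yt) and the clamping finishes in the next hunk. Written out in one place, with a hypothetical helper name, it amounts to:

def center_size_to_box(x_c: float, y_c: float, w: float, h: float) -> dict:
    # Convert a normalized center + size into a box clamped to the image bounds.
    clamp = lambda v: max(0.0, min(1.0, v))
    return {
        "x_min": clamp(x_c - w / 2),
        "x_max": clamp(x_c + w / 2),
        "y_min": clamp(y_c - h / 2),
        "y_max": clamp(y_c + h / 2),
    }

assert center_size_to_box(0.95, 0.5, 0.2, 0.2)["x_max"] == 1.0   # clamped at the image edge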
@@ -1096,27 +1099,24 @@ class MoondreamModel(nn.Module):
                            "y_max": max(0.0, min(1.0, yb)),
                        })

+                    attn[idx, 0, 0, pos_ids[idx, 0]] = True
+                    logits, hidden = self._decode_one_tok(size_emb, attn, pos_ids, lora)
+                    pos_ids[idx, 0] += 1

                    next_tok = logits.argmax(dim=-1)
                    if next_tok.dim() == 3: next_tok = next_tok.squeeze(-1).squeeze(-1)
                    if next_tok.dim() == 2: next_tok = next_tok.squeeze(1)
                else:
+                    for i in idx.tolist():
                        out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
+                    attn[idx, 0, 0, pos_ids[idx, 0]] = True
+                    logits, hidden = self._decode_one_tok(y_emb, attn, pos_ids, lora)
+                    pos_ids[idx, 0] += 1
                    next_tok = logits.argmax(dim=-1)
                    if next_tok.dim() == 3: next_tok = next_tok.squeeze(-1).squeeze(-1)
                    if next_tok.dim() == 2: next_tok = next_tok.squeeze(1)

+                counts[alive] += 1
                # stop rows that hit eos OR reached quota
                finished_now = (next_tok == eos_id) | (counts >= max_objects)
                alive &= ~finished_now
@@ -1124,6 +1124,7 @@ class MoondreamModel(nn.Module):


+
    def detect_multi(self, image, objects, settings=None):
        if self.config.tokenizer.templates["detect"] is None:
            raise NotImplementedError("Model does not support object detection.")
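Usage note: with this change, detect_multi drives one batched prefill over all labels followed by one shared per-row decode loop. A hypothetical call (the signature comes from the diff; the model handle, example image, labels, and the exact shape of the return value are illustrative assumptions, not taken from the repo):

from PIL import Image

# `model` is assumed to be an already-loaded MoondreamModel instance.
image = Image.open("street.jpg")
result = model.detect_multi(image, ["person", "car", "traffic light"])
# e.g. result for "car" -> [{"x_min": 0.18, "y_min": 0.41, "x_max": 0.37, "y_max": 0.66}, ...]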