HV-Khurdula committed on
Commit 5f6e97e · verified · 1 Parent(s): 13973a3

Update moondream.py


fix: per-row generation for granularity.

Files changed (1)
  1. moondream.py (+56 -55)
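The core of the change is in the prefill step: rows were previously left-padded and every row's hidden state was read at the shared tail index T - 1, while the new code right-pads each row with its last real embedding and gathers at each row's own index Li - 1. Below is a minimal, standalone sketch of that gather pattern (toy sizes; the names lens, embs, prompt_emb are illustrative stand-ins, not the model's API), followed by the actual diff.

# Toy sketch only: shows the right-pad + per-row gather that the prefill change
# relies on; the tensors here are made up and no Moondream code is involved.
import torch

torch.manual_seed(0)
C = 4                                     # embedding width (made up)
lens = [3, 5, 2]                          # per-row prompt lengths Li
T = max(lens)
embs = [torch.randn(L, C) for L in lens]  # stand-ins for text_encoder outputs

padded = []
for e, L in zip(embs, lens):
    pad = T - L
    if pad > 0:
        e = torch.cat([e, e[-1:].repeat(pad, 1)], dim=0)  # right-pad with last embedding
    padded.append(e)
prompt_emb = torch.stack(padded, dim=0)                   # (B, T, C)

B = len(lens)
last_idx = torch.tensor([L - 1 for L in lens])            # each row's true last token
last_emb = prompt_emb[torch.arange(B), last_idx]          # (B, C)

# The gather lands on each row's real last embedding, not on a pad slot.
for row, L in enumerate(lens):
    assert torch.equal(last_emb[row], embs[row][L - 1])
print(last_idx.tolist())                                  # [2, 4, 1]

With right padding, index T - 1 belongs to a pad copy for any row shorter than T, so gathering at Li - 1 is what keeps each label's logits tied to its own prompt rather than to the longest prompt in the batch.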
moondream.py CHANGED
@@ -951,42 +951,50 @@ class MoondreamModel(nn.Module):
         temperature: float = 0.0,
         top_p: float = 0.0,
     ):
+        """
+        Batch prefill for multiple detection labels.
+
+        - Right-pads each row with its *last* embedding so the true last token for
+          each row is still at index (len-1). We then take that per-row index.
+        - Advances KV to a common end position (pos + T) for all rows.
+        """
         tpl = self.config.tokenizer.templates["detect"]
         if tpl is None:
             raise NotImplementedError("Model does not support object detection.")

-        # 1) Tokenize each label (variable lengths Li)
+        # Tokenize rows (variable lengths Li)
         rows_ids, lens = [], []
         for lab in labels:
             ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
             t = torch.tensor(ids, device=self.device, dtype=torch.long)
             rows_ids.append(t)
-            lens.append(t.numel())
+            lens.append(int(t.numel()))

         B = len(rows_ids)
         T = max(lens)

-        # 2) Embed and LEFT-pad each row with its first token embedding
-        embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids]  # list[(Li, C)]
+        # Embed, then RIGHT-pad by repeating the last real token embedding
+        embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids]  # (Li, C)
         padded = []
         for e, L in zip(embs, lens):
             pad = T - L
             if pad > 0:
-                e = torch.cat([e[:1].repeat(pad, 1), e], dim=0)  # (T, C)
+                e = torch.cat([e, e[-1:].repeat(pad, 1)], dim=0)  # (T, C)
             padded.append(e)
         prompt_emb = torch.stack(padded, dim=0)  # (B, T, C)
         torch._dynamo.mark_dynamic(prompt_emb, 1)

-        # 3) Prefill over the shared image prefix [pos : pos+T)
+        # Shared mask over the image prefix; broadcast to B
         base = self.attn_mask[:, :, pos : pos + T, :]  # (1,1,T,K)
-        mask = base.expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
-        pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)
+        attn_mask = base.expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
+        pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)

-        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B,T,C)
-        logits_BTV = lm_head(hidden_BTC, self.text)  # (B,T,V)
+        # Prefill
+        hidden_BTC = self._prefill(prompt_emb, attn_mask, pos_ids, lora)  # (B,T,C)
+        logits_BTV = lm_head(hidden_BTC, self.text)  # (B,T,V)

-        # *** IMPORTANT: take the tail position for every row ***
-        last_idx = torch.full((B,), T - 1, device=self.device, dtype=torch.long)
+        # For each row, pick its *true* last token (Li-1), not a padded index
+        last_idx = torch.tensor([L - 1 for L in lens], device=self.device, dtype=torch.long)  # (B,)

         last_hidden = hidden_BTC[torch.arange(B, device=self.device), last_idx][:, None, :]  # (B,1,C)
         last_logits = logits_BTV[torch.arange(B, device=self.device), last_idx]  # (B,V)
@@ -998,17 +1006,15 @@ class MoondreamModel(nn.Module):
         probs = self._apply_top_p(probs, top_p)
         next_token = torch.multinomial(probs, num_samples=1)  # (B,1)

+        # We advanced KV for T steps for everyone; decoding starts after that slot.
         pos_end = int(pos + T)
         return last_hidden, next_token, pos_end

-
-
-
     def _generate_points_batched(
         self,
-        hidden,      # (B,1,C) - last token hidden state per row
-        next_token,  # (B,1)   - unused for greedy loop; kept for API
-        pos,         # int     - first free position in cache
+        hidden,      # (B,1,C) last token hidden per row
+        next_token,  # (B,1)
+        pos,         # int: first free KV slot (after prefill)
         include_size: bool = True,
         max_objects: int = 50,
         lora=None,
@@ -1020,15 +1026,14 @@ class MoondreamModel(nn.Module):
         eos_id = self.config.tokenizer.eos_id
         max_ctx = self.config.text.max_context

-        # 4D mask: (B,1,1,K); we advance per-row
-        mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
-        p0 = int(pos)
-        if p0 > 0:
-            mask[:, :, :, :p0] = True
-        pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
+        # Per-row decoding mask & pos pointer
+        attn = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)  # (B,1,1,K)
+        if pos > 0:
+            attn[:, :, :, :pos] = True
+        pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)

-        # helper: logits -> normalized [0..1] coordinate (soft-argmax for stability)
         def _argmax01(logits: torch.Tensor) -> torch.Tensor:
+            # returns normalized [0,1] bin position
             if logits.dim() == 3:
                 logits = logits.squeeze(1)  # (B, bins)
             if use_soft_argmax:
@@ -1043,31 +1048,30 @@ class MoondreamModel(nn.Module):

         with torch.inference_mode():
             while alive.any() and (counts < max_objects).any():
-                alive_idx = alive.nonzero(as_tuple=False).squeeze(1)
+                idx = alive.nonzero(as_tuple=False).squeeze(1)

-                # ---------- x ----------
-                x_logits = decode_coordinate(hidden, self.region)  # (B,1,bins) or (B,bins)
-                x_center = _argmax01(x_logits)  # (B,)
-                x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)  # (B,1,C)
+                # ---- x ----
+                x_logits = decode_coordinate(hidden, self.region)
+                x_center = _argmax01(x_logits)
+                x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)

-                # advance one token for each alive row (per-row column)
-                mask[alive_idx, 0, 0, pos_ids[alive_idx, 0]] = True
-                logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
-                pos_ids[alive_idx, 0] += 1
+                attn[idx, 0, 0, pos_ids[idx, 0]] = True
+                logits, hidden = self._decode_one_tok(x_emb, attn, pos_ids, lora)
+                pos_ids[idx, 0] += 1

-                # ---------- y ----------
+                # ---- y ----
                 y_logits = decode_coordinate(hidden, self.region)
-                y_center = _argmax01(y_logits)  # (B,)
+                y_center = _argmax01(y_logits)
                 y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)

-                mask[alive_idx, 0, 0, pos_ids[alive_idx, 0]] = True
-                logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
-                pos_ids[alive_idx, 0] += 1
+                attn[idx, 0, 0, pos_ids[idx, 0]] = True
+                logits, hidden = self._decode_one_tok(y_emb, attn, pos_ids, lora)
+                pos_ids[idx, 0] += 1

                 if include_size:
-                    # ---------- size (w,h) ----------
-                    size_ret = decode_size(hidden, self.region)  # (...,2,bins)
-                    w_logits, h_logits = self._norm_size_logits(size_ret, B)  # each (B,bins)
+                    # ---- (w,h) ----
+                    size_ret = decode_size(hidden, self.region)  # (...,2,bins)
+                    w_logits, h_logits = self._norm_size_logits(size_ret, B)

                     if use_soft_argmax:
                         bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
@@ -1083,8 +1087,7 @@ class MoondreamModel(nn.Module):

                     size_emb = encode_size(torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), self.region).unsqueeze(1)

-                    # write outputs only for alive rows
-                    for i in alive_idx.tolist():
+                    for i in idx.tolist():
                         xl = (x_center[i] - w[i] / 2).item()
                         xr = (x_center[i] + w[i] / 2).item()
                         yt = (y_center[i] - h[i] / 2).item()
@@ -1096,27 +1099,24 @@ class MoondreamModel(nn.Module):
                             "y_max": max(0.0, min(1.0, yb)),
                         })

-                    mask[alive_idx, 0, 0, pos_ids[alive_idx, 0]] = True
-                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
-                    pos_ids[alive_idx, 0] += 1
+                    attn[idx, 0, 0, pos_ids[idx, 0]] = True
+                    logits, hidden = self._decode_one_tok(size_emb, attn, pos_ids, lora)
+                    pos_ids[idx, 0] += 1

                     next_tok = logits.argmax(dim=-1)
                     if next_tok.dim() == 3: next_tok = next_tok.squeeze(-1).squeeze(-1)
                     if next_tok.dim() == 2: next_tok = next_tok.squeeze(1)
                 else:
-                    # points only
-                    for i in alive_idx.tolist():
+                    for i in idx.tolist():
                         out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
-                    mask[alive_idx, 0, 0, pos_ids[alive_idx, 0]] = True
-                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
-                    pos_ids[alive_idx, 0] += 1
+                    attn[idx, 0, 0, pos_ids[idx, 0]] = True
+                    logits, hidden = self._decode_one_tok(y_emb, attn, pos_ids, lora)
+                    pos_ids[idx, 0] += 1
                     next_tok = logits.argmax(dim=-1)
                     if next_tok.dim() == 3: next_tok = next_tok.squeeze(-1).squeeze(-1)
                     if next_tok.dim() == 2: next_tok = next_tok.squeeze(1)

-                counts[alive] += 1  # we produced one object/point for each alive row
-
-                # stop rows that hit eos OR reached quota
+                counts[alive] += 1
                 finished_now = (next_tok == eos_id) | (counts >= max_objects)
                 alive &= ~finished_now

@@ -1124,6 +1124,7 @@ class MoondreamModel(nn.Module):



+
     def detect_multi(self, image, objects, settings=None):
         if self.config.tokenizer.templates["detect"] is None:
             raise NotImplementedError("Model does not support object detection.")