HV-Khurdula committed on
Commit
c366261
·
verified ·
1 Parent(s): 9a7633c

Update moondream.py

Browse files

fix: batched generation

Files changed (1) hide show
  1. moondream.py +59 -59
moondream.py CHANGED
@@ -943,7 +943,14 @@ class MoondreamModel(nn.Module):
943
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
944
 
945
 
946
- def _prefill_prompt_batched(self, labels, pos: int, lora=None, temperature: float = 0.0, top_p: float = 0.0):
 
 
 
 
 
 
 
947
  tpl = self.config.tokenizer.templates["detect"]
948
  if tpl is None:
949
  raise NotImplementedError("Model does not support object detection.")
@@ -953,43 +960,49 @@ class MoondreamModel(nn.Module):
953
  ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
954
  t = torch.tensor(ids, device=self.device, dtype=torch.long)
955
  rows.append(t); lens.append(t.numel())
 
956
  B, T = len(rows), max(lens)
957
  eos = self.config.tokenizer.eos_id
958
 
 
959
  prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
960
  for i, ids in enumerate(rows):
961
  prompt_ids[i, : ids.numel()] = ids
962
 
963
- prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
964
  torch._dynamo.mark_dynamic(prompt_emb, 1)
965
 
966
- base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,K)
967
- mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
968
 
969
  pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
970
- hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
971
- logits_BTV = lm_head(hidden_BTC, self.text) # (B,T,V)
972
 
973
- idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
 
974
  last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
975
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
976
 
977
  if temperature == 0.0:
978
- next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
979
  else:
980
  probs = torch.softmax(last_logits / temperature, dim=-1)
981
  probs = self._apply_top_p(probs, top_p)
982
- next_token = torch.multinomial(probs, num_samples=1) # (B,1)
 
 
 
983
 
984
- pos_end = int(pos + T) # shared next-free position
985
- return last_hidden, next_token, pos_end
986
 
987
 
988
  def _generate_points_batched(
989
  self,
990
- hidden, # (B,1,C)
991
- next_token, # (B,1) (unused for greedy)
992
- pos, # int (start position in cache)
993
  include_size: bool = True,
994
  max_objects: int = 50,
995
  lora=None,
@@ -999,18 +1012,17 @@ class MoondreamModel(nn.Module):
999
  device = self.device
1000
  out = [[] for _ in range(B)]
1001
  eos_id = self.config.tokenizer.eos_id
 
1002
  max_ctx = self.config.text.max_context
1003
 
1004
- # 4-D mask: (B, 1, q_len=1, kv_len)
1005
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
1006
- p0 = int(pos)
1007
- if p0 > 0:
1008
- mask[:, :, :, :p0] = True
 
 
1009
 
1010
- # per-row position ids (B,1)
1011
- pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
1012
-
1013
- # helper: (B, bins) -> (B,) in [0,1]
1014
  def _argmax01(logits: torch.Tensor) -> torch.Tensor:
1015
  if use_soft_argmax:
1016
  probs = torch.softmax(logits, dim=-1)
@@ -1019,14 +1031,11 @@ class MoondreamModel(nn.Module):
1019
  idx = logits.argmax(dim=-1).to(torch.float32)
1020
  return idx / float(logits.size(-1) - 1)
1021
 
1022
- # advance-one-step for a subset of rows (alive only)
1023
  def _advance_rows(row_mask: torch.Tensor):
1024
  idx = row_mask.nonzero(as_tuple=False).flatten()
1025
- # set each row's next KV column true
1026
  for i in idx.tolist():
1027
  col = int(pos_ids[i, 0].item())
1028
  mask[i, 0, 0, col] = True
1029
- # decoder step (all rows run, but only alive rows’ pos_ids move)
1030
  return idx
1031
 
1032
  alive = torch.ones(B, dtype=torch.bool, device=device)
@@ -1034,39 +1043,29 @@ class MoondreamModel(nn.Module):
1034
 
1035
  with torch.inference_mode():
1036
  while alive.any() and (counts < max_objects).any():
1037
-
1038
- # ---------------- x ----------------
1039
- x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
1040
- if x_logits.dim() == 3:
1041
- x_logits = x_logits.squeeze(1) # -> (B,1024)
1042
- x_center = _argmax01(x_logits) # (B,)
1043
- x_emb = encode_coordinate(
1044
- x_center.to(dtype=x_logits.dtype).unsqueeze(-1), # (B,1)
1045
- self.region
1046
- ).unsqueeze(1) # (B,1,C)
1047
 
1048
  idx = _advance_rows(alive)
1049
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1050
  pos_ids[idx, 0] += 1
1051
 
1052
- # ---------------- y ----------------
1053
  y_logits = decode_coordinate(hidden, self.region)
1054
- if y_logits.dim() == 3:
1055
- y_logits = y_logits.squeeze(1)
1056
  y_center = _argmax01(y_logits)
1057
- y_emb = encode_coordinate(
1058
- y_center.to(dtype=y_logits.dtype).unsqueeze(-1),
1059
- self.region
1060
- ).unsqueeze(1)
1061
 
1062
  idx = _advance_rows(alive)
1063
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1064
  pos_ids[idx, 0] += 1
1065
 
1066
  if include_size:
1067
- # ------------- size (w,h) -------------
1068
  size_ret = decode_size(hidden, self.region)
1069
- w_logits, h_logits = self._norm_size_logits(size_ret, B) # each (B,C)
1070
 
1071
  if use_soft_argmax:
1072
  bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
@@ -1076,16 +1075,12 @@ class MoondreamModel(nn.Module):
1076
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1077
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1078
 
1079
- # inverse log-scale mapping used by MD2
1080
- w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0) # (B,)
1081
- h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0) # (B,)
1082
 
1083
- size_emb = encode_size(
1084
- torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), # (B,2)
1085
- self.region
1086
- ).unsqueeze(1) # (B,1,C)
1087
 
1088
- # record boxes only for ALIVE rows
1089
  for i in alive.nonzero(as_tuple=False).flatten().tolist():
1090
  xl = (x_center[i] - w[i] / 2).item()
1091
  xr = (x_center[i] + w[i] / 2).item()
@@ -1103,7 +1098,6 @@ class MoondreamModel(nn.Module):
1103
  pos_ids[idx, 0] += 1
1104
  next_tok = logits.argmax(dim=-1)
1105
  else:
1106
- # point mode
1107
  for i in alive.nonzero(as_tuple=False).flatten().tolist():
1108
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1109
  idx = _advance_rows(alive)
@@ -1111,16 +1105,22 @@ class MoondreamModel(nn.Module):
1111
  pos_ids[idx, 0] += 1
1112
  next_tok = logits.argmax(dim=-1)
1113
 
1114
- # normalize next_tok to shape (B,)
1115
  while next_tok.dim() > 1:
1116
  next_tok = next_tok.squeeze(-1)
1117
 
 
1118
  counts[alive] += 1
1119
- finished_now = (next_tok == eos_id) | (counts >= max_objects)
 
 
 
 
1120
  alive &= ~finished_now
1121
 
1122
  return out
1123
 
 
1124
  def detect_multi(self, image, objects, settings=None):
1125
  if self.config.tokenizer.templates["detect"] is None:
1126
  raise NotImplementedError("Model does not support object detection.")
@@ -1132,17 +1132,17 @@ class MoondreamModel(nn.Module):
1132
 
1133
  lora = variant_state_dict(settings["variant"], device=self.device) if "variant" in settings else None
1134
 
1135
- last_hidden, next_token, pos_end = self._prefill_prompt_batched(
1136
  objects, enc.pos, lora=lora, temperature=0.0, top_p=0.0
1137
  )
1138
-
1139
  det_lists = self._generate_points_batched(
1140
- last_hidden, next_token, pos_end,
1141
  include_size=True,
1142
  max_objects=settings.get("max_objects", 50),
1143
  lora=lora,
1144
  )
1145
-
1146
  res = {}
1147
  for lab, lst in zip(objects, det_lists):
1148
  for d in lst:
 
943
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
944
 
945
 
946
+ def _prefill_prompt_batched(
947
+ self,
948
+ labels,
949
+ pos: int,
950
+ lora=None,
951
+ temperature: float = 0.0,
952
+ top_p: float = 0.0,
953
+ ):
954
  tpl = self.config.tokenizer.templates["detect"]
955
  if tpl is None:
956
  raise NotImplementedError("Model does not support object detection.")
 
960
  ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
961
  t = torch.tensor(ids, device=self.device, dtype=torch.long)
962
  rows.append(t); lens.append(t.numel())
963
+
964
  B, T = len(rows), max(lens)
965
  eos = self.config.tokenizer.eos_id
966
 
967
+ # Pad with EOS in the tensor, but we will still start generation per-row at its own length
968
  prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
969
  for i, ids in enumerate(rows):
970
  prompt_ids[i, : ids.numel()] = ids
971
 
972
+ prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
973
  torch._dynamo.mark_dynamic(prompt_emb, 1)
974
 
975
+ base = self.attn_mask[:, :, pos : pos + T, :] # (1,1,T,K)
976
+ mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
977
 
978
  pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
979
+ hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
980
+ logits_BTV = lm_head(hidden_BTC, self.text) # (B,T,V)
981
 
982
+ # Gather last real token per row
983
+ idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
984
  last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
985
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
986
 
987
  if temperature == 0.0:
988
+ next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
989
  else:
990
  probs = torch.softmax(last_logits / temperature, dim=-1)
991
  probs = self._apply_top_p(probs, top_p)
992
+ next_token = torch.multinomial(probs, num_samples=1) # (B,1)
993
+
994
+ # Per-row next positions (don’t force them all to pos+T)
995
+ pos_vec = (pos + torch.tensor(lens, device=self.device, dtype=torch.long)) # (B,)
996
 
997
+ return last_hidden, next_token, pos_vec
998
+
999
 
1000
 
1001
  def _generate_points_batched(
1002
  self,
1003
+ hidden, # (B,1,C)
1004
+ next_token, # (B,1) (unused for greedy)
1005
+ pos_vec, # (B,) next-free position per row
1006
  include_size: bool = True,
1007
  max_objects: int = 50,
1008
  lora=None,
 
1012
  device = self.device
1013
  out = [[] for _ in range(B)]
1014
  eos_id = self.config.tokenizer.eos_id
1015
+ coord_id = self.config.tokenizer.coord_id
1016
  max_ctx = self.config.text.max_context
1017
 
1018
+ # Build per-row masks/positions
1019
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
1020
+ pos_ids = pos_vec.clone().view(B, 1) # (B,1)
1021
+ for i in range(B):
1022
+ p0 = int(pos_ids[i, 0].item())
1023
+ if p0 > 0:
1024
+ mask[i, 0, 0, :p0] = True
1025
 
 
 
 
 
1026
  def _argmax01(logits: torch.Tensor) -> torch.Tensor:
1027
  if use_soft_argmax:
1028
  probs = torch.softmax(logits, dim=-1)
 
1031
  idx = logits.argmax(dim=-1).to(torch.float32)
1032
  return idx / float(logits.size(-1) - 1)
1033
 
 
1034
  def _advance_rows(row_mask: torch.Tensor):
1035
  idx = row_mask.nonzero(as_tuple=False).flatten()
 
1036
  for i in idx.tolist():
1037
  col = int(pos_ids[i, 0].item())
1038
  mask[i, 0, 0, col] = True
 
1039
  return idx
1040
 
1041
  alive = torch.ones(B, dtype=torch.bool, device=device)
 
1043
 
1044
  with torch.inference_mode():
1045
  while alive.any() and (counts < max_objects).any():
1046
+ # -------- x --------
1047
+ x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
1048
+ if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
1049
+ x_center = _argmax01(x_logits) # (B,)
1050
+ x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
 
 
 
 
 
1051
 
1052
  idx = _advance_rows(alive)
1053
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1054
  pos_ids[idx, 0] += 1
1055
 
1056
+ # -------- y --------
1057
  y_logits = decode_coordinate(hidden, self.region)
1058
+ if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
 
1059
  y_center = _argmax01(y_logits)
1060
+ y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
 
 
 
1061
 
1062
  idx = _advance_rows(alive)
1063
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1064
  pos_ids[idx, 0] += 1
1065
 
1066
  if include_size:
 
1067
  size_ret = decode_size(hidden, self.region)
1068
+ w_logits, h_logits = self._norm_size_logits(size_ret, B) # (B,C)
1069
 
1070
  if use_soft_argmax:
1071
  bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
 
1075
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1076
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1077
 
1078
+ w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
1079
+ h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
 
1080
 
1081
+ size_emb = encode_size(torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), self.region).unsqueeze(1)
 
 
 
1082
 
1083
+ # record boxes only for rows still alive
1084
  for i in alive.nonzero(as_tuple=False).flatten().tolist():
1085
  xl = (x_center[i] - w[i] / 2).item()
1086
  xr = (x_center[i] + w[i] / 2).item()
 
1098
  pos_ids[idx, 0] += 1
1099
  next_tok = logits.argmax(dim=-1)
1100
  else:
 
1101
  for i in alive.nonzero(as_tuple=False).flatten().tolist():
1102
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1103
  idx = _advance_rows(alive)
 
1105
  pos_ids[idx, 0] += 1
1106
  next_tok = logits.argmax(dim=-1)
1107
 
1108
+ # normalize next_tok to (B,)
1109
  while next_tok.dim() > 1:
1110
  next_tok = next_tok.squeeze(-1)
1111
 
1112
+ # we added exactly one object/point to all alive rows
1113
  counts[alive] += 1
1114
+
1115
+ # GRAMMAR STOP: only continue if the model asks to start another coord;
1116
+ # otherwise stop row (covers EOS or any non-coord token).
1117
+ continue_mask = (next_tok == coord_id)
1118
+ finished_now = (~continue_mask) | (counts >= max_objects)
1119
  alive &= ~finished_now
1120
 
1121
  return out
1122
 
1123
+
1124
  def detect_multi(self, image, objects, settings=None):
1125
  if self.config.tokenizer.templates["detect"] is None:
1126
  raise NotImplementedError("Model does not support object detection.")
 
1132
 
1133
  lora = variant_state_dict(settings["variant"], device=self.device) if "variant" in settings else None
1134
 
1135
+ last_hidden, next_token, pos_vec = self._prefill_prompt_batched(
1136
  objects, enc.pos, lora=lora, temperature=0.0, top_p=0.0
1137
  )
1138
+
1139
  det_lists = self._generate_points_batched(
1140
+ last_hidden, next_token, pos_vec,
1141
  include_size=True,
1142
  max_objects=settings.get("max_objects", 50),
1143
  lora=lora,
1144
  )
1145
+
1146
  res = {}
1147
  for lab, lst in zip(objects, det_lists):
1148
  for d in lst: