HV-Khurdula committed
Commit 13973a3 · verified · Parent(s): 15e66ee

Update moondream.py


fix: change pad token

Files changed (1):
  1. moondream.py +17 -14
moondream.py CHANGED
@@ -945,7 +945,7 @@ class MoondreamModel(nn.Module):
 
     def _prefill_prompt_batched(
         self,
-        labels: List[str],
+        labels,
         pos: int,
         lora=None,
         temperature: float = 0.0,
@@ -955,7 +955,7 @@ class MoondreamModel(nn.Module):
         if tpl is None:
             raise NotImplementedError("Model does not support object detection.")
 
-        # 1) Build token ids for each label (variable length)
+        # 1) Tokenize each label (variable lengths Li)
         rows_ids, lens = [], []
         for lab in labels:
             ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
@@ -966,7 +966,7 @@ class MoondreamModel(nn.Module):
         B = len(rows_ids)
         T = max(lens)
 
-        # 2) Embed then LEFT-pad each row to length T using the row’s first token embedding
+        # 2) Embed and LEFT-pad each row with its first token embedding
         embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids]  # list[(Li, C)]
         padded = []
         for e, L in zip(embs, lens):
@@ -977,30 +977,33 @@ class MoondreamModel(nn.Module):
         prompt_emb = torch.stack(padded, dim=0)  # (B, T, C)
         torch._dynamo.mark_dynamic(prompt_emb, 1)
 
-        # 3) Prefill over the shared image prefix [pos : pos + T)
-        base = self.attn_mask[:, :, pos:pos + T, :]  # (1, 1, T, K)
-        mask = base.expand(B, -1, -1, -1).contiguous()  # (B, 1, T, K)
+        # 3) Prefill over the shared image prefix [pos : pos+T)
+        base = self.attn_mask[:, :, pos : pos + T, :]  # (1,1,T,K)
+        mask = base.expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
         pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)
-        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B, T, C)
-        logits_BTV = lm_head(hidden_BTC, self.text)  # (B, T, V)
 
-        # **FIX**: After left-padding, the last real token sits at T-1 for every row.
-        last_idx = torch.full((B,), T - 1, device=self.device, dtype=torch.long)  # (B,)
-        last_hidden = hidden_BTC[torch.arange(B, device=self.device), last_idx][:, None, :]  # (B, 1, C)
-        last_logits = logits_BTV[torch.arange(B, device=self.device), last_idx]  # (B, V)
+        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B,T,C)
+        logits_BTV = lm_head(hidden_BTC, self.text)  # (B,T,V)
+
+        # *** IMPORTANT: take the tail position for every row ***
+        last_idx = torch.full((B,), T - 1, device=self.device, dtype=torch.long)
+
+        last_hidden = hidden_BTC[torch.arange(B, device=self.device), last_idx][:, None, :]  # (B,1,C)
+        last_logits = logits_BTV[torch.arange(B, device=self.device), last_idx]  # (B,V)
 
         if temperature == 0.0:
-            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B, 1)
+            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,1)
         else:
             probs = torch.softmax(last_logits / temperature, dim=-1)
             probs = self._apply_top_p(probs, top_p)
-            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)
+            next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
 
         pos_end = int(pos + T)
         return last_hidden, next_token, pos_end
 
 
 
+
     def _generate_points_batched(
         self,
         hidden,  # (B,1,C) - last token hidden state per row
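
For context, here is a minimal, self-contained sketch of the padding scheme the hunk relies on. It is not part of moondream.py; the helper name left_pad_rows and the toy tensors are illustrative only. Because every row is LEFT-padded up to the common length T, using its own first-token embedding as filler, the last real token of every row ends up at index T - 1, so a single gather at T - 1 is valid for the whole batch:

# Minimal sketch (not from moondream.py): left-pad variable-length rows with each
# row's first embedding so the last real token of every row sits at index T - 1.
import torch

def left_pad_rows(rows):
    # rows: list of (Li, C) tensors with varying Li
    T = max(r.shape[0] for r in rows)
    padded = []
    for r in rows:
        fill = r[:1].expand(T - r.shape[0], -1)      # repeat the row's first embedding
        padded.append(torch.cat([fill, r], dim=0))   # pad on the LEFT
    return torch.stack(padded, dim=0)                # (B, T, C)

rows = [torch.randn(2, 4), torch.randn(5, 4), torch.randn(3, 4)]
batch = left_pad_rows(rows)
B, T, _ = batch.shape
last_idx = torch.full((B,), T - 1, dtype=torch.long)
last = batch[torch.arange(B), last_idx]              # (B, C): one index works for every row
assert all(torch.allclose(r[-1], l) for r, l in zip(rows, last))
print(batch.shape, last.shape)                       # torch.Size([3, 5, 4]) torch.Size([3, 4])

With right padding, the last real token would sit at a different index per row (lens[i] - 1), which is exactly the ambiguity the shared T - 1 index avoids.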
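Under the same caveat, a standalone sketch of the greedy/top-p branch at the end of the hunk: nucleus_filter below is an illustrative stand-in for self._apply_top_p (whose implementation is not shown in this diff); it keeps the smallest set of tokens whose cumulative probability reaches top_p, renormalizes, and then torch.multinomial draws one token per row.

# Illustrative only: greedy vs. nucleus (top-p) sampling over per-row logits.
import torch

def nucleus_filter(probs, top_p):
    # Zero out tokens outside the smallest set whose cumulative mass reaches top_p.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cum = sorted_probs.cumsum(dim=-1)
    sorted_probs[cum - sorted_probs > top_p] = 0.0   # always keeps the top-1 token
    filtered = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    return filtered / filtered.sum(dim=-1, keepdim=True)

def sample_next(last_logits, temperature, top_p):
    if temperature == 0.0:
        return last_logits.argmax(dim=-1, keepdim=True)       # greedy, (B, 1)
    probs = torch.softmax(last_logits / temperature, dim=-1)
    probs = nucleus_filter(probs, top_p)
    return torch.multinomial(probs, num_samples=1)             # sampled, (B, 1)

logits = torch.randn(3, 10)
print(sample_next(logits, 0.0, 0.9).shape, sample_next(logits, 0.7, 0.9).shape)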