Update moondream.py
fix: change pad token

moondream.py  +17 -14
@@ -945,7 +945,7 @@ class MoondreamModel(nn.Module):
 
     def _prefill_prompt_batched(
         self,
-        labels
+        labels,
         pos: int,
         lora=None,
         temperature: float = 0.0,
@@ -955,7 +955,7 @@ class MoondreamModel(nn.Module):
         if tpl is None:
             raise NotImplementedError("Model does not support object detection.")
 
-        # 1)
+        # 1) Tokenize each label (variable lengths Li)
         rows_ids, lens = [], []
         for lab in labels:
             ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
@@ -966,7 +966,7 @@ class MoondreamModel(nn.Module):
         B = len(rows_ids)
         T = max(lens)
 
-        # 2) Embed
+        # 2) Embed and LEFT-pad each row with its first token embedding
         embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids]  # list[(Li, C)]
         padded = []
         for e, L in zip(embs, lens):
@@ -977,30 +977,33 @@ class MoondreamModel(nn.Module):
         prompt_emb = torch.stack(padded, dim=0)  # (B, T, C)
         torch._dynamo.mark_dynamic(prompt_emb, 1)
 
-        # 3) Prefill over the shared image prefix [pos : pos
-        base = self.attn_mask[:, :, pos:pos + T, :]
-        mask = base.expand(B, -1, -1, -1).contiguous()
+        # 3) Prefill over the shared image prefix [pos : pos+T)
+        base = self.attn_mask[:, :, pos : pos + T, :]  # (1,1,T,K)
+        mask = base.expand(B, -1, -1, -1).contiguous()  # (B,1,T,K)
         pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)
-        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B, T, C)
-        logits_BTV = lm_head(hidden_BTC, self.text)  # (B, T, V)
 
-
-
-
-
+        hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B,T,C)
+        logits_BTV = lm_head(hidden_BTC, self.text)  # (B,T,V)
+
+        # *** IMPORTANT: take the tail position for every row ***
+        last_idx = torch.full((B,), T - 1, device=self.device, dtype=torch.long)
+
+        last_hidden = hidden_BTC[torch.arange(B, device=self.device), last_idx][:, None, :]  # (B,1,C)
+        last_logits = logits_BTV[torch.arange(B, device=self.device), last_idx]  # (B,V)
 
         if temperature == 0.0:
-            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,
+            next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,1)
         else:
             probs = torch.softmax(last_logits / temperature, dim=-1)
             probs = self._apply_top_p(probs, top_p)
-            next_token = torch.multinomial(probs, num_samples=1)  # (B,
+            next_token = torch.multinomial(probs, num_samples=1)  # (B,1)
 
         pos_end = int(pos + T)
         return last_hidden, next_token, pos_end
 
 
 
+
     def _generate_points_batched(
         self,
         hidden,  # (B,1,C) - last token hidden state per row