Atharva Mete
committed on
Commit · 57b4d23
1 Parent(s): 303e3cf
vla added but giving nans in loss
Browse files
- added_tokens.json +2 -0
- config.json +2 -0
- config_molmo.py +4 -0
- modeling_molmo.py +53 -2
- preprocessing_molmo.py +20 -4
- special_tokens_map.json +3 -1
- tokenizer_config.json +19 -1
added_tokens.json
CHANGED
@@ -7,6 +7,8 @@
   "<|im_end|>": 151645,
   "<|im_start|>": 151644,
   "<|image|>": 152068,
+  "<|proprio|>": 152069,
+  "<|skill|>": 152070,
   "|<EXTRA_TOKENS_0>|": 151646,
   "|<EXTRA_TOKENS_100>|": 151746,
   "|<EXTRA_TOKENS_101>|": 151747,
config.json
CHANGED
@@ -28,5 +28,7 @@
   "use_cache": true,
   "use_position_ids": true,
   "vocab_size": 152064,
+  "skill_vocab_size": 1000,
+  "additional_vocab_size": 128,
   "weight_tying": false
 }
config_molmo.py
CHANGED
@@ -9,6 +9,8 @@ class MolmoConfig(PretrainedConfig):
 
     def __init__(
         self,
+        skill_vocab_size=1000,
+        additional_vocab_size=128,
         vocab_size=50304,
         embedding_size=50304,
         hidden_size=4096,
@@ -31,6 +33,8 @@ class MolmoConfig(PretrainedConfig):
         layer_norm_type: str="rms",
         **kwargs,
     ):
+        self.skill_vocab_size = skill_vocab_size
+        self.additional_vocab_size = additional_vocab_size
         self.vocab_size = vocab_size
         self.embedding_size = embedding_size
         self.max_position_embeddings = max_position_embeddings
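For orientation, a minimal sketch (not part of the commit) of how the new config fields partition the token space, assuming the sizes in config.json above and the concatenation order used by Embedding.forward in modeling_molmo.py:

vocab_size = 152064            # base vocabulary ("vocab_size" in config.json)
additional_vocab_size = 128    # extra tokens such as <|image|>, <|proprio|>, <|skill|>
skill_vocab_size = 1000        # discrete skill codes decoded by the new head

total_vocab_size = vocab_size + additional_vocab_size + skill_vocab_size  # 153192

# The newly registered special tokens fall inside the "additional" slice:
for token, token_id in {"<|image|>": 152068, "<|proprio|>": 152069, "<|skill|>": 152070}.items():
    assert vocab_size <= token_id < vocab_size + additional_vocab_size, token

# Under the concatenated embedding table, skill code i would occupy an absolute id of
skill_code = 5
absolute_id = vocab_size + additional_vocab_size + skill_code  # 152197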
modeling_molmo.py
CHANGED
@@ -541,6 +541,7 @@ class Embedding(nn.Module):
         self,
         num_embeddings: int,
         num_new_embeddings: int,
+        num_skill_embeddings: int,
         features: int,
         device: Union[str, torch.device],
         initializer_range: float = 0.02,
@@ -555,13 +556,17 @@
         self.new_embedding = nn.Parameter(
             torch.zeros(num_new_embeddings, features, device=device),
         )
+        self.skill_embedding = nn.Parameter(
+            torch.zeros(num_skill_embeddings, features, device=device),
+        )
 
     def reset_parameters(self):
         nn.init.normal_(self.embedding, std=self.initializer_range)
         nn.init.normal_(self.new_embedding, std=self.new_embed_initializer_range)
+        nn.init.normal_(self.skill_embedding, std=self.new_embed_initializer_range)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return F.embedding(x, torch.cat([self.embedding, self.new_embedding], dim=0))
+        return F.embedding(x, torch.cat([self.embedding, self.new_embedding, self.skill_embedding], dim=0))
 
 
 class Dropout(nn.Dropout):
@@ -681,6 +686,7 @@ class FullMolmoConfig:
     initializer_range: float = 0.02
     normalize_input_embeds: bool = False
     use_position_ids: bool = True
+    skill_vocab_size: int = 1000
 
     @property
     def effective_n_kv_heads(self) -> int:
@@ -1695,6 +1701,7 @@ class Molmo(nn.Module):
         wte = Embedding(
             config.embedding_size or config.vocab_size,
             config.additional_vocab_size,
+            config.skill_vocab_size,
             config.d_model,
             device=config.init_device,
             initializer_range=config.initializer_range,
@@ -1734,6 +1741,16 @@ class Molmo(nn.Module):
                 )
             }
         )
+        self.transformer.update(
+            {
+                "skill_ff_out": nn.Linear(
+                    config.d_model,
+                    config.skill_vocab_size,
+                    bias=config.include_bias,
+                    device=config.init_device,
+                )
+            }
+        )
 
         self.vision_backbone: Optional[OLMoVisionBackbone] = None
         if config.vision_backbone is not None:
@@ -1741,6 +1758,11 @@ class Molmo(nn.Module):
 
         self.__num_fwd_flops: Optional[int] = None
 
+        self.total_vocab_size = config.vocab_size + config.additional_vocab_size + config.skill_vocab_size
+        torch.nn.init.xavier_uniform_(self.transformer.skill_ff_out.weight)
+        if self.transformer.skill_ff_out.bias is not None:
+            torch.nn.init.zeros_(self.transformer.skill_ff_out.bias)
+
     def reset_parameters(self):
         if self.vision_backbone is not None:
             self.vision_backbone.reset_parameters()
@@ -1778,12 +1800,15 @@ class Molmo(nn.Module):
         image_masks: Optional[torch.Tensor] = None,
         image_input_idx: Optional[torch.Tensor] = None,
         subsegment_ids: Optional[torch.Tensor] = None,
+        proprio_embeds: Optional[torch.Tensor] = None,
+        proprio_idx: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
         use_cache: bool = False,
         last_logits_only: bool = False,
         output_hidden_states: Optional[bool] = None,
         append_last_valid_logits: Optional[torch.Tensor] = None,
+        mode: Optional[str] = "vla",
     ) -> ModelOutput:
         """
         :param input_ids: A tensor of shape `(batch_size, seq_len)`.
@@ -1880,6 +1905,9 @@ class Molmo(nn.Module):
             image_features = image_features.to(x.device)
 
             x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
+
+        if proprio_embeds is not None:
+            x[batch_idx, proprio_idx] += proprio_embeds
 
         if not self.config.rope:
             # Get positional embeddings.
@@ -1997,7 +2025,14 @@ class Molmo(nn.Module):
         if self.config.weight_tying:
             logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
         else:
-            logits = self.transformer.ff_out(x)  # type: ignore
+            if mode == "vla":
+                logits = self.transformer.skill_ff_out(x)
+                # this little trick allows us to use HF generate() while decoding
+                if use_cache:
+                    filler_logits = torch.full((x.shape[0], x.shape[1], self.total_vocab_size-self.config.skill_vocab_size), -math.inf, device=logits.device)
+                    logits = torch.cat([filler_logits, logits], dim=-1)  # type: ignore
+            else:
+                logits = self.transformer.ff_out(x)  # type: ignore
         if self.config.scale_logits:
             logits.mul_(1 / math.sqrt(self.config.d_model))
 
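The mode == "vla" branch above routes hidden states through skill_ff_out and, when use_cache is set (i.e. during decoding), left-pads the skill logits with -inf so that HF generate() sees a full-vocabulary distribution; training with use_cache=False never sees the filler. A small self-contained sketch (illustrative shapes, not the commit's code) of how sampled ids map back to skill indices under that layout:

import math
import torch

vocab_size, additional_vocab_size, skill_vocab_size = 152064, 128, 1000   # sizes from config.json
total_vocab_size = vocab_size + additional_vocab_size + skill_vocab_size

batch, seq = 2, 1
skill_logits = torch.randn(batch, seq, skill_vocab_size)

# Same padding trick as Molmo.forward: -inf for every non-skill slot,
# so greedy/sampled ids can only land in the skill region.
filler = torch.full((batch, seq, total_vocab_size - skill_vocab_size), -math.inf)
padded = torch.cat([filler, skill_logits], dim=-1)

token_ids = padded.argmax(dim=-1)                                # absolute ids >= 152192
skill_ids = token_ids - (total_vocab_size - skill_vocab_size)    # back into [0, 1000)
assert torch.equal(skill_ids, skill_logits.argmax(dim=-1))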
@@ -2039,6 +2074,7 @@ class MolmoForCausalLM(PreTrainedModel):
             mlp_hidden_size=config.intermediate_size,
             n_layers=config.num_hidden_layers,
             additional_vocab_size=128,
+            skill_vocab_size=config.skill_vocab_size,
             n_heads=config.num_attention_heads,
             n_kv_heads=config.num_key_value_heads,
             rope_theta=config.rope_theta,
@@ -2080,6 +2116,8 @@ class MolmoForCausalLM(PreTrainedModel):
         image_masks: Optional[torch.Tensor] = None,
         image_input_idx: Optional[torch.Tensor] = None,
         subsegment_ids: Optional[torch.Tensor] = None,
+        proprio_embeds: Optional[torch.Tensor] = None,
+        proprio_idx: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         past_key_values: Optional[List[torch.FloatTensor]] = None,
         labels: Optional[torch.LongTensor] = None,
@@ -2113,6 +2151,8 @@ class MolmoForCausalLM(PreTrainedModel):
             image_masks=image_masks,
             image_input_idx=image_input_idx,
             subsegment_ids=subsegment_ids,
+            proprio_embeds=proprio_embeds,
+            proprio_idx=proprio_idx,
             position_ids=position_ids,
             past_key_values=past_key_values,
             use_cache=use_cache,
@@ -2185,6 +2225,8 @@ class MolmoForCausalLM(PreTrainedModel):
         images = batch.get("images")
         image_masks = batch.get("image_masks")
         image_input_idx = batch.get("image_input_idx")
+        proprio_embeds = batch.get("proprio_embeds")
+        proprio_idx = batch.get("proprio_idx")
 
         # Validate inputs.
         input_ids = batch["input_ids"]
@@ -2217,6 +2259,8 @@ class MolmoForCausalLM(PreTrainedModel):
            image_masks=image_masks,
            image_input_idx=image_input_idx,
            position_ids=position_ids,
+           proprio_embeds=proprio_embeds,
+           proprio_idx=proprio_idx,
            append_last_valid_logits=append_last_valid_logits,
            **kwargs,
        )
@@ -2235,6 +2279,8 @@ class MolmoForCausalLM(PreTrainedModel):
         images = kwargs.get("images")
         image_masks = kwargs.get("image_masks")
         image_input_idx = kwargs.get("image_input_idx")
+        proprio_embeds = kwargs.get("proprio_embeds")
+        proprio_idx = kwargs.get("proprio_idx")
         position_ids = kwargs.get("position_ids")
         append_last_valid_logits = kwargs.get("append_last_valid_logits")
         model_inputs = {
@@ -2250,6 +2296,8 @@ class MolmoForCausalLM(PreTrainedModel):
             model_inputs["image_masks"] = image_masks
             model_inputs["image_input_idx"] = image_input_idx
             model_inputs["append_last_valid_logits"] = append_last_valid_logits
+            model_inputs["proprio_embeds"] = proprio_embeds
+            model_inputs["proprio_idx"] = proprio_idx
         else:
             model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}
 
@@ -2272,6 +2320,9 @@ class MolmoForCausalLM(PreTrainedModel):
             del model_kwargs["images"]
             del model_kwargs["image_masks"]
             del model_kwargs["image_input_idx"]
+        if "proprio_embeds" in model_kwargs:
+            del model_kwargs["proprio_embeds"]
+            del model_kwargs["proprio_idx"]
         cache_name, cache = super()._extract_past_from_model_output(outputs)
         model_kwargs[cache_name] = cache
         model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
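The proprioception path mirrors the image path: the processor records where the <|proprio|> placeholders land (proprio_idx), and Molmo.forward adds externally computed proprio_embeds onto the input embeddings at those positions. A standalone sketch of that scatter-add pattern (shapes are illustrative assumptions, not taken from the commit):

import torch

batch, seq, d_model, num_proprio = 2, 16, 3584, 4

x = torch.zeros(batch, seq, d_model)                    # token embeddings after wte
proprio_embeds = torch.randn(batch, num_proprio, d_model)
proprio_idx = torch.tensor([[5, 6, 7, 8],               # positions of <|proprio|> per sample
                            [3, 4, 5, 6]])

batch_idx = torch.arange(batch).unsqueeze(1)            # (batch, 1) broadcasts against proprio_idx
x[batch_idx, proprio_idx] += proprio_embeds             # same pattern as the image-feature scatter

assert torch.allclose(x[0, 5], proprio_embeds[0, 0])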
preprocessing_molmo.py
CHANGED
@@ -28,7 +28,7 @@ from transformers.utils import logging
 
 from transformers import AutoTokenizer
 from .image_preprocessing_molmo import MolmoImagesKwargs, MolmoImageProcessor
-
+from typing import List, Union
 
 logger = logging.get_logger(__name__)
 
@@ -38,9 +38,14 @@ DEFAULT_IM_START_TOKEN = f"<im_start>"
 DEFAULT_IM_END_TOKEN = f"<im_end>"
 DEFAULT_IM_COL_TOKEN = f"<im_col>"
 IMAGE_PROMPT = "<|image|>"
+PROPRIO_PROMPT = "<|proprio|>"
+SKILL_PROMPT = "<|skill|>"
 
-EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT)
+EXTRA_TOKENS = (DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_COL_TOKEN, IMAGE_PROMPT, PROPRIO_PROMPT, SKILL_PROMPT)
 
+ProprioInput = Union[
+    np.ndarray, "torch.Tensor", List[np.ndarray], List["torch.Tensor"]
+]
 
 def get_special_token_ids(tokenizer):
     ids = tokenizer.encode("".join(EXTRA_TOKENS), add_special_tokens=False)
@@ -72,7 +77,7 @@ class MolmoProcessorKwargs(ProcessingKwargs, total=False):
     "text_kwargs": {
         "style": "long_caption",
         "system_prompt": "none",
-        "message_format": "role",
+        "message_format": "robot",
         "always_start_with_space": True,
         "sequence_length": 1536,
         "padding": False,
@@ -97,11 +102,14 @@ class MolmoProcessor(ProcessorMixin):
         self._special_tokens = get_special_token_ids(self.tokenizer)
         return self._special_tokens
 
-    def get_tokens_input(self, prompt, message_format, always_start_with_space):
+    def get_tokens_input(self, prompt, message_format, always_start_with_space, num_proprio):
         if message_format == "none" or message_format is None:
             pass
         elif message_format == "role":
             prompt = "User: " + prompt + " Assistant:"
+        elif message_format == "robot":
+            # this adds proprio observations after the prompt
+            prompt = "User: " + prompt + PROPRIO_PROMPT*num_proprio + " Assistant:"
         else:
             raise NotImplementedError(f"Message format {message_format} not implemented")
 
@@ -116,6 +124,7 @@ class MolmoProcessor(ProcessorMixin):
         self,
         text: TextInput = None,
         images: ImageInput = None,
+        proprio: ProprioInput = None,
         *,
         tokens: Optional[PreTokenizedInput] = None,
         **kwargs: Unpack[MolmoProcessorKwargs],
@@ -126,14 +135,18 @@ class MolmoProcessor(ProcessorMixin):
             **kwargs,
         )
 
+        num_proprio = len(proprio) if proprio is not None else 0
+
         if tokens is None:
             tokens = self.get_tokens_input(
                 text,
                 output_kwargs["text_kwargs"]["message_format"],
                 output_kwargs["text_kwargs"]["always_start_with_space"],
+                num_proprio
             )
 
         image_token_id = self.special_token_ids[IMAGE_PROMPT]
+        proprio_token_id = self.special_token_ids[PROPRIO_PROMPT]
 
         if images is not None:
             if not isinstance(images, (list, tuple)):
@@ -182,6 +195,9 @@ class MolmoProcessor(ProcessorMixin):
         # Shift patch mapping up by one since we added BOS
         image_input_idx = out["image_input_idx"]
         out["image_input_idx"] = np.where(image_input_idx < 0, image_input_idx, image_input_idx + 1)
+
+        proprio_idx = np.where(out["input_ids"] == proprio_token_id)[0]
+        out["proprio_idx"] = proprio_idx
 
         for k, v in out.items():
             out[k] = torch.from_numpy(v)
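A hypothetical usage sketch of the extended processor (placeholder repo id; the encoder that turns raw proprio readings into proprio_embeds is outside this commit): with the "robot" message format the prompt gains one <|proprio|> placeholder per proprio observation, and proprio_idx records where they ended up in input_ids.

import numpy as np
from PIL import Image
from transformers import AutoProcessor

# Placeholder repo id; trust_remote_code loads preprocessing_molmo.py from this commit.
processor = AutoProcessor.from_pretrained("your-org/molmo-vla", trust_remote_code=True)

image = Image.new("RGB", (336, 336))
proprio = np.zeros((4, 8), dtype=np.float32)    # e.g. 4 proprio observations of dimension 8

inputs = processor.process(
    text="pick up the red block",
    images=[image],
    proprio=proprio,                            # len(proprio) == 4 -> four <|proprio|> tokens
)
print(inputs["proprio_idx"])                    # indices of <|proprio|> within input_ids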
special_tokens_map.json
CHANGED
@@ -422,7 +422,9 @@
     "<im_end>",
     "<im_patch>",
     "<im_col>",
-    "<|image|>"
+    "<|image|>",
+    "<|proprio|>",
+    "<|skill|>"
   ],
   "eos_token": {
     "content": "<|endoftext|>",
tokenizer_config.json
CHANGED
@@ -3408,6 +3408,22 @@
       "rstrip": false,
       "single_word": false,
       "special": true
+    },
+    "152069": {
+      "content": "<|proprio|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "152070": {
+      "content": "<|skill|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
     }
   },
   "additional_special_tokens": [
@@ -3833,7 +3849,9 @@
     "<im_end>",
     "<im_patch>",
     "<im_col>",
-    "<|image|>"
+    "<|image|>",
+    "<|proprio|>",
+    "<|skill|>"
   ],
   "auto_map": {
     "AutoProcessor": "preprocessing_molmo.MolmoProcessor"