Upload modeling_molmo.py with huggingface_hub
modeling_molmo.py  CHANGED  (+20, -273)
@@ -123,7 +123,7 @@ class RotaryEmbedding(nn.Module):
         inv_freq = 1.0 / (self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
         seq = torch.arange(seq_len, device=device, dtype=torch.float)
         freqs = torch.einsum("i , j -> i j", seq, inv_freq)
-        if self.config.rope_impl == "
+        if self.config.rope_impl == "interleave":
             positions = freqs.repeat_interleave(2, dim=-1)
         else:
             positions = torch.cat((freqs, freqs), dim=-1)
@@ -146,7 +146,7 @@ class RotaryEmbedding(nn.Module):
         return x.view(B, nh, T, hs)

     def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-        if self.config.rope_impl == "
+        if self.config.rope_impl == "interleave":
             return ((t * pos_cos) + (self.rotate_every_two(t) * pos_sin)).to(t.dtype)
         else:
             return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
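For context on the two rope_impl branches above: both layouts apply the same rotation and only differ in how channels are paired ("interleave" rotates adjacent channel pairs, the other variant rotates the first and second halves of the head dimension). A minimal standalone sketch in plain PyTorch; the *_demo names are illustrative and not part of this file:

import torch

def rotate_every_two_demo(x: torch.Tensor) -> torch.Tensor:
    # interleaved layout: out[..., 2i] = -x[..., 2i+1], out[..., 2i+1] = x[..., 2i]
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def rotate_half_demo(x: torch.Tensor) -> torch.Tensor:
    # half layout: swap the two halves of the last dimension with a sign flip
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_demo(t: torch.Tensor, pos_sin: torch.Tensor, pos_cos: torch.Tensor, interleave: bool) -> torch.Tensor:
    # mirrors apply_rotary_pos_emb above; pos_sin/pos_cos must be laid out to match
    # (repeat_interleave for the interleaved case, cat((freqs, freqs)) otherwise)
    rotated = rotate_every_two_demo(t) if interleave else rotate_half_demo(t)
    return (t * pos_cos) + (rotated * pos_sin)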
@@ -205,7 +205,7 @@ class MolmoBlock(nn.Module):
         self._activation_checkpoint_fn = None

         # Dropout.
-        self.dropout = Dropout(config.residual_dropout
+        self.dropout = Dropout(config.residual_dropout)

         # Layer norms.
         self.k_norm: Optional[LayerNormBase] = None
@@ -298,7 +298,6 @@ class MolmoBlock(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
         response_dropout_p: float = 0.0,
         is_causal: bool = False,
@@ -341,7 +340,6 @@ class MolmoBlock(nn.Module):
         v: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -394,7 +392,6 @@ class MolmoBlock(nn.Module):
             k,
             v,
             attn_mask=attention_bias,
-            drop_mask=drop_mask,
             dropout_p=0.0 if not self.training else self.config.attention_dropout,
             response_dropout_p=0.0 if not self.training else self.config.response_attention_dropout,
             is_causal=attention_bias is None,
@@ -411,7 +408,6 @@ class MolmoBlock(nn.Module):
         x: torch.Tensor,
         attention_bias: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -419,183 +415,7 @@ class MolmoBlock(nn.Module):

     @classmethod
     def build(cls, layer_id: int, config: MolmoConfig, cache: BufferCache):
-
-            return MolmoSequentialBlock(layer_id, config, cache)
-        elif config.block_type == "llama":
-            return OLMoLlamaBlock(layer_id, config, cache)
-        else:
-            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
+        return MolmoSequentialBlock(layer_id, config, cache)
-
-
-class OLMoLlamaBlock(MolmoBlock):
-    """
-    This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
-    (plus another skip connection). This block is similar to `MolmoSequentialBlock`
-    but some operations have slightly different implementations to imitate the
-    behavior of Llama.
-    """
-
-    def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache):
-        super().__init__(layer_id, config, cache)
-        # Layer norms.
-        self.attn_norm = LayerNorm.build(config)
-        self.ff_norm = LayerNorm.build(config)
-        self.__cache = cache
-
-        # Attention input projection. Projects x -> (q, k, v)
-        q_proj_out_dim = config.d_model
-        k_proj_out_dim = config.effective_n_kv_heads * (config.d_model // config.n_heads)
-        v_proj_out_dim = config.effective_n_kv_heads * (config.d_model // config.n_heads)
-
-        self.q_proj = nn.Linear(
-            config.d_model, q_proj_out_dim, bias=config.qkv_bias, device=config.init_device
-        )
-        self.k_proj = nn.Linear(
-            config.d_model, k_proj_out_dim, bias=config.qkv_bias, device=config.init_device
-        )
-        self.v_proj = nn.Linear(
-            config.d_model, v_proj_out_dim, bias=config.qkv_bias, device=config.init_device
-        )
-
-        # Feed-forward input projection.
-        self.ff_proj1 = nn.Linear(
-            config.d_model, self.hidden_size // 2, bias=False, device=config.init_device
-        )
-        self.ff_proj2 = nn.Linear(
-            config.d_model, self.hidden_size // 2, bias=False, device=config.init_device
-        )
-        if self.config.norm_after:
-            raise NotImplementedError()
-
-    def reset_parameters(self):
-        super().reset_parameters()
-        self.attn_norm.reset_parameters()
-        self.ff_norm.reset_parameters()
-        # NOTE: the standard deviation for these weights does not depend on the layer.
-        init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
-        init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
-        init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
-        init_weights(self.config, self.ff_proj1, d=self.config.d_model, layer_id=None)
-        init_weights(self.config, self.ff_proj2, d=self.config.d_model, layer_id=None)
-
-    def _scaled_dot_product_attention(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
-        dropout_p: float = 0.0,
-        response_dropout_p: float = 0.0,
-        is_causal: bool = False,
-    ) -> torch.Tensor:
-        # For GQA
-        assert k.size(1) == v.size(1)
-        num_kv_heads = k.size(1)
-        num_q_heads = q.size(1)
-        if num_q_heads != num_kv_heads:
-            assert num_q_heads % num_kv_heads == 0
-            k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
-            v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
-
-        og_dtype = q.dtype
-        k = k.to(q.device)
-        v = v.to(q.device)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(q.device)
-
-        assert response_dropout_p == 0.0, "Response dropout is not supported in Llama."
-
-        if self.config.float32_attention:
-            q, k = q.to(torch.float), k.to(torch.float)
-
-        if self.config.attention_type == "direct":
-            attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
-
-            if is_causal:
-                assert attn_mask is None
-
-                query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None
-                attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
-            elif attn_mask is not None:
-                attn_bias = attn_mask
-            else:
-                attn_bias = torch.zeros_like(attn_weights)
-
-            attn_weights += attn_bias
-
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
-            attn_weights = nn.functional.dropout(attn_weights, p=dropout_p, training=self.training).to(v.dtype)
-
-            att = torch.matmul(attn_weights, v)
-        elif self.config.attention_type == "sdpa":
-            att = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=attn_mask,
-                dropout_p=dropout_p,
-                is_causal=is_causal,
-            )
-        else:
-            raise NotImplementedError(self.config.attention_type)
-        att = att.to(og_dtype)
-        return att
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        attention_bias: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        use_cache: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
-        # Get query, key, value projections.
-        # shape:
-        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
-        #  - for multi-query attn q: (batch_size, seq_len, d_model)
-        #                      k, v: (batch_size, seq_len, d_model // n_heads)
-        x_normed = self.attn_norm(x)
-        q = self.q_proj(x_normed)
-        k = self.k_proj(x_normed)
-        v = self.v_proj(x_normed)
-
-        if self.config.clip_qkv is not None:
-            q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
-            k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
-            v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
-
-        # Get attention scores.
-        if self._activation_checkpoint_fn is not None:
-            att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, position_ids=position_ids, drop_mask=drop_mask, layer_past=layer_past, use_cache=use_cache
-            )
-        else:
-            att, cache = self.attention(q, k, v, attention_bias, position_ids=position_ids, drop_mask=drop_mask, layer_past=layer_past, use_cache=use_cache)
-
-        # Add attention scores.
-        # shape: (B, T, C)
-        x = x + self.dropout(att, drop_mask=drop_mask)
-
-        # Add feed-forward projection.
-        # shape: (batch_size, seq_len, d_model)
-        og_x = x
-        if self._activation_checkpoint_fn is not None:
-            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
-        else:
-            x = self.ff_norm(x)
-        x1 = self.ff_proj1(x)
-        x2 = self.ff_proj2(x)
-        if self._activation_checkpoint_fn is not None:
-            x = self._activation_checkpoint_fn(self.act, x1, x2)  # type: ignore
-        else:
-            x = self.act(x1, x2)
-        x = self.ff_out(x)
-        x = self.dropout(x, drop_mask=drop_mask)
-        x = og_x + x
-
-        return x, cache


 class MolmoSequentialBlock(MolmoBlock):
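The removed OLMoLlamaBlock handled grouped-query attention by expanding K and V to the query head count with repeat_interleave and then running ordinary scaled-dot-product attention. A standalone sketch of that pattern (assumes PyTorch 2.x for F.scaled_dot_product_attention; gqa_sdpa_demo is an illustrative name, not from the file):

import torch
import torch.nn.functional as F

def gqa_sdpa_demo(q, k, v, is_causal=True):
    # q: (B, n_q_heads, T, head_dim); k, v: (B, n_kv_heads, T, head_dim)
    num_q_heads, num_kv_heads = q.size(1), k.size(1)
    if num_q_heads != num_kv_heads:
        assert num_q_heads % num_kv_heads == 0
        # every group of query heads shares one key/value head
        k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
        v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

q = torch.randn(2, 8, 5, 16)   # 8 query heads
k = torch.randn(2, 2, 5, 16)   # 2 key/value heads
v = torch.randn(2, 2, 5, 16)
print(gqa_sdpa_demo(q, k, v).shape)   # torch.Size([2, 8, 5, 16])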
@@ -644,7 +464,6 @@ class MolmoSequentialBlock(MolmoBlock):
         x: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -673,10 +492,10 @@ class MolmoSequentialBlock(MolmoBlock):
         # Get attention scores.
         if self._activation_checkpoint_fn is not None:
             att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, position_ids=position_ids,
+                self.attention, q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
             )
         else:
-            att, cache = self.attention(q, k, v, attention_bias, position_ids=position_ids,
+            att, cache = self.attention(q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)

         if self.config.norm_after:
             if self._activation_checkpoint_fn is not None:
@@ -686,7 +505,7 @@ class MolmoSequentialBlock(MolmoBlock):

         # Add attention scores.
         # shape: (B, T, C)
-        x = x + self.dropout(att
+        x = x + self.dropout(att)

         # Add feed-forward projection.
         # shape: (batch_size, seq_len, d_model)
@@ -711,7 +530,7 @@ class MolmoSequentialBlock(MolmoBlock):
         else:
             x = self.ff_norm(x)

-        x = self.dropout(x
+        x = self.dropout(x)
         x = og_x + x

         return x, cache
@@ -757,27 +576,14 @@ class Dropout(nn.Dropout):
         self.mask_p = mask_p
         self.broadcast_dims = broadcast_dims

-    def forward(self, input: torch.Tensor
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         """
         :param input: A tensor of shape `(batch_size, seq_len, embed_dim)`
-        :param drop_mask: A tensor of shape `(batch_size, seq_len)` with values of zero or one.
         """
         if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0):
             return input
         else:
-            if self.
-                assert drop_mask is not None
-                drop_mask = drop_mask.to(input.dtype)
-                keep_prob = 1.0 - self.p
-                keep_prob2 = 1.0 - self.mask_p
-                keep_prob = drop_mask * keep_prob2 + (1 - drop_mask) * keep_prob
-                keep_prob = keep_prob.unsqueeze(-1)
-                dropout_shape = list(input.shape)
-                keep_prob = keep_prob.broadcast_to(dropout_shape)
-                multiplier = input.new_empty(dropout_shape).bernoulli_(keep_prob)
-                multiplier.div_(keep_prob)
-                return input * multiplier
-            elif self.p > 0. and len(self.broadcast_dims) > 0 and self.training:
+            if self.p > 0. and len(self.broadcast_dims) > 0 and self.training:
                 keep_prob = 1.0 - self.p
                 dropout_shape = list(input.shape)
                 for dim in self.broadcast_dims:
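The branch deleted from Dropout.forward implemented per-token "response" dropout: positions flagged by drop_mask were dropped with probability mask_p, all other positions with probability p, with the usual inverted-dropout rescaling so expectations are preserved. A minimal sketch of that math in plain PyTorch (the function name is illustrative, not from the file):

import torch

def masked_dropout_demo(x: torch.Tensor, drop_mask: torch.Tensor, p: float, mask_p: float) -> torch.Tensor:
    # x: (B, T, D); drop_mask: (B, T), 1.0 where mask_p should apply instead of p
    drop_mask = drop_mask.to(x.dtype)
    keep_prob = drop_mask * (1.0 - mask_p) + (1.0 - drop_mask) * (1.0 - p)   # (B, T)
    keep_prob = keep_prob.unsqueeze(-1).broadcast_to(x.shape)                # (B, T, D)
    multiplier = x.new_empty(x.shape).bernoulli_(keep_prob) / keep_prob      # inverted dropout
    return x * multiplier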
@@ -792,7 +598,6 @@ class Dropout(nn.Dropout):

 @dataclass
 class VisionBackboneConfig:
-    image_model_type: str = "openai"
     image_default_input_size: Tuple[int, int] = (336, 336)
     image_patch_size: int = 14
     image_pos_patch_size: int = 14
@@ -832,17 +637,12 @@ class FullMolmoConfig:
     mlp_ratio: int = 4
     mlp_hidden_size: Optional[int] = None
     activation_type: str = "swiglu"
-    block_type: str = "sequential"
     block_group_size: int = 1
-
-    alibi_bias_max: float = 8.0
-    rope: bool = False
+    rope: bool = True
     rope_full_precision: bool = True
     rope_theta: float = 10000.
-    rope_impl: str = "
+    rope_impl: str = "interleave"
     vision_backbone: Optional[VisionBackboneConfig] = None
-    vit_load_path: Optional[str] = None
-    llm_load_path: Optional[str] = None
     attention_type: str = "sdpa"
     float32_attention: bool = True
     attention_dropout: float = 0.1
@@ -850,7 +650,6 @@ class FullMolmoConfig:
     multi_query_attention: Optional[bool] = None
     attention_layer_norm: bool = False
     residual_dropout: float = 0.1
-    response_residual_dropout: float = 0.0
     embedding_dropout: float = 0.1
     layer_norm_type: str = "default"
     layer_norm_with_affine: bool = True
@@ -872,10 +671,6 @@ class FullMolmoConfig:
     init_cutoff_factor: Optional[float] = None
     norm_after: bool = False
     precision: Optional[str] = None
-    max_crops: int = 12
-    crop_mode: str = "patchify-v2-and-resize-c2"
-    do_random_scale: bool = True
-    use_col_tokens: bool = True
     image_padding_embed: Optional[str] = None
     vit_layers: Tuple = (-1,)
     image_pooling_h: int = 2
@@ -883,12 +678,9 @@ class FullMolmoConfig:
     image_pooling_2d: str = "attention"
     image_projector: str = "mlp"
     image_feature_dropout: float = 0.0
-    use_cls_feature: bool = False
     initializer_range: float = 0.02
-    pad_tokenizer: bool = False
     normalize_input_embeds: bool = False
     use_position_ids: bool = True
-    query_pre_attn_scalar: int = 224

     @property
     def effective_n_kv_heads(self) -> int:
@@ -1112,7 +904,7 @@ class VisionTransformer(nn.Module):
         if patch_num is None:
             patch_num = self.config.vision_backbone.image_num_patch
         B, N, D = x.shape
-
+
         x = self.patch_embedding(x)

         # class embeddings and positional embeddings
@@ -1526,15 +1318,6 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):

         self.num_prefix_tokens = self.image_vit.num_prefix_tokens
         assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported"
-        if config.use_cls_feature:
-            assert self.num_prefix_tokens > 0, "The model does not have a CLS token"
-            nlayers = 1 if config.vit_layers is None else len(config.vit_layers)
-            self.cls_projector = nn.Linear(
-                nlayers * v_cfg.image_emb_dim,
-                self.input_dim,
-                bias=False,
-                device=config.init_device,
-            )

         self.pad_embed = None
         if config.image_padding_embed:
@@ -1551,8 +1334,6 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
     def reset_parameters(self):
         super().reset_parameters()
         self.image_vit.reset_parameters()
-        if self.config.use_cls_feature:
-            nn.init.xavier_uniform_(self.cls_projector.weight)

     def encode_image(self, images: torch.Tensor) -> torch.Tensor:
         """
@@ -1562,7 +1343,7 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
         v_cfg = self.config.vision_backbone
         B, T, N, D = images.shape

-        mask = torch.all(images.view(B * T, N, D)
+        mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)

         # Output all hidden states
         # n_layers x (batch_num_crops, (1+)n_tokens, image_emb_dim)
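The rewritten mask line above marks an image crop as padding when every value in it equals -1. A small self-contained illustration of that convention (shapes and values are made up; the torch.all call is the same one used in the new line):

import torch

B, T, N, D = 2, 3, 4, 6                  # batch, crops per image, patches per crop, patch dim
images = torch.randn(B, T, N, D)
images[0, 2] = -1                        # pretend the last crop of sample 0 is padding

flat = images.view(B * T, N, D)
mask = ~torch.all(flat == -1, dim=(1, 2), keepdim=True)   # (B*T, 1, 1), False for padded crops
print(mask.squeeze())                    # tensor([ True,  True, False,  True,  True,  True])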
@@ -1658,9 +1439,6 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
         else:
             image_features = self.image_projector(image_features)

-        if self.config.use_cls_feature:
-            raise NotImplementedError()
-
         # image_features: (batch_size, num_image, num_patch, d_model)
         # cls_embed: (batch_size, num_image, d_model)
         return image_features, cls_embed
@@ -1944,7 +1722,7 @@ class Molmo(nn.Module):
         else:
             self.transformer.update({"blocks": nn.ModuleList(blocks)})

-        if not
+        if not self.config.rope:
             self.transformer.update(
                 {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
             )
@@ -2105,23 +1883,7 @@ class Molmo(nn.Module):

             x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]

-
-                x = torch.cat([x[:, :1], cls_embed, x[:, 1:-num_image]], dim=1)
-
-                valid_images = torch.any(
-                    (image_input_idx >= 0).view(batch_size, num_image, num_patch), dim=-1
-                )
-                valid_images = valid_images.to(attention_mask.dtype)
-                attention_mask = torch.cat(
-                    [attention_mask[:, :1], valid_images, attention_mask[:, 1:-num_image]],
-                    dim=1,
-                )
-                position_ids = torch.clamp(
-                    torch.cumsum(attention_mask, dim=-1) - 1,
-                    min=0,
-                ).broadcast_to((batch_size, attention_mask.shape[-1]))
-
-        if not (self.config.alibi or self.config.rope):
+        if not self.config.rope:
             # Get positional embeddings.
             # shape: (1, seq_len)
             pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
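The removed block above rebuilt position_ids from the attention mask with a clamped cumulative sum, so masked-out positions do not advance the position counter. A tiny illustration with made-up values:

import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0, 1, 1]], dtype=torch.float)
position_ids = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0).long()
print(position_ids)   # tensor([[0, 1, 2, 2, 2, 3, 4]])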
@@ -2151,17 +1913,12 @@ class Molmo(nn.Module):
         if (
             attention_bias is not None
             or attention_mask is not None
-            or self.config.alibi
             # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
             # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
             # scores correctly.
             or past_key_values is not None
         ):
-            if attention_bias is None
-                attention_bias = get_causal_attention_bias(
-                    self.__cache, past_length + seq_len, x.device
-                ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
-            elif attention_bias is None:
+            if attention_bias is None:
                 attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
                 attention_bias = attention_bias.to(dtype=torch.float)
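Context for the NOTE kept above: when a key/value cache is in use, the code materializes an additive causal bias over all past_length + seq_len keys and hands it to SDPA, instead of relying on is_causal. A standalone sketch of that shape logic (this is not the file's get_causal_attention_bias, just an illustration):

import torch

def causal_bias_demo(total_len: int, device=None) -> torch.Tensor:
    # 0 where attention is allowed, a large negative value strictly above the diagonal
    bias = torch.triu(torch.ones(total_len, total_len, device=device), diagonal=1)
    bias = bias.masked_fill(bias == 1, torch.finfo(torch.float).min)
    return bias.view(1, 1, total_len, total_len)

past_length, q_len = 3, 2                         # 3 cached tokens, 2 new tokens
bias = causal_bias_demo(past_length + q_len)[:, :, past_length:, :]
print(bias.shape)                                 # torch.Size([1, 1, 2, 5])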
@@ -2196,7 +1953,7 @@ class Molmo(nn.Module):
                 all_hidden_states.append(x)

             layer_past = None if past_key_values is None else past_key_values[block_idx]
-            x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids,
+            x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)

             if attn_key_values is not None:
                 assert cache is not None
@@ -2215,19 +1972,12 @@ class Molmo(nn.Module):
                     ]
                 )
                 x, cache = block_group(
-                    x, attention_bias=attention_bias, position_ids=position_ids,
+                    x, attention_bias=attention_bias, position_ids=position_ids, layers_past=layers_past, use_cache=use_cache
                 )
                 if attn_key_values is not None:
                     assert cache is not None
                     attn_key_values.extend(cache)

-        if images is not None and self.config.use_cls_feature:
-            assert num_image is not None
-            x = torch.cat(
-                [x[:, :1], x[:, num_image+1:], torch.zeros_like(x[:, :num_image])],
-                dim=1,
-            )
-
         if last_logits_only:
             # shape: (batch_size, 1, d_model)
             if append_last_valid_logits is not None:
@@ -2271,9 +2021,9 @@ class MolmoForCausalLM(PreTrainedModel):

         if not model:
             full_config = FullMolmoConfig(
-                attention_layer_norm=config.attention_layer_norm,
                 image_padding_embed="pad_and_partial_pad",
                 image_pooling_2d="attention-meanq",
+                attention_layer_norm=config.attention_layer_norm,
                 rope_impl="llama",
                 vocab_size=config.vocab_size,
                 max_sequence_length=config.max_position_embeddings,
@@ -2282,7 +2032,6 @@ class MolmoForCausalLM(PreTrainedModel):
                 embedding_size=config.embedding_size,
                 attention_type="sdpa",
                 embedding_dropout=0,
-                response_residual_dropout=0,
                 attention_dropout=0,
                 residual_dropout=0,
                 rope=True,
@@ -2297,10 +2046,8 @@ class MolmoForCausalLM(PreTrainedModel):
                 rope_theta=config.rope_theta,
                 layer_norm_eps=config.layer_norm_eps,
                 layer_norm_type=config.layer_norm_type,
-                pad_tokenizer=True,
                 vit_layers=[-2, -9],
                 vision_backbone=VisionBackboneConfig(
-                    image_model_type="openai",
                     image_default_input_size=(336, 336),
                     image_patch_size=14,
                     image_pos_patch_size=14,