chrisc36 committed
Commit: d383de4
Parent(s): b72f674

Upload modeling_molmo.py with huggingface_hub

Files changed (1): modeling_molmo.py (+20 −273)
modeling_molmo.py CHANGED
@@ -123,7 +123,7 @@ class RotaryEmbedding(nn.Module):
         inv_freq = 1.0 / (self.config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
         seq = torch.arange(seq_len, device=device, dtype=torch.float)
         freqs = torch.einsum("i , j -> i j", seq, inv_freq)
-        if self.config.rope_impl == "cockatoo":
+        if self.config.rope_impl == "interleave":
             positions = freqs.repeat_interleave(2, dim=-1)
         else:
             positions = torch.cat((freqs, freqs), dim=-1)
@@ -146,7 +146,7 @@ class RotaryEmbedding(nn.Module):
         return x.view(B, nh, T, hs)
 
     def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-        if self.config.rope_impl == "cockatoo":
+        if self.config.rope_impl == "interleave":
             return ((t * pos_cos) + (self.rotate_every_two(t) * pos_sin)).to(t.dtype)
         else:
             return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)
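The two `rope_impl` branches correspond to the two common RoPE channel layouts: "interleave" pairs adjacent channels (GPT-J-style `rotate_every_two`), while the default pairs the first and second halves of the head dimension (`rotate_half`). A minimal sketch of what such rotation helpers conventionally do, assuming plain tensors of shape (..., head_dim); the bodies below are illustrative, not copied from this file:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # Pair adjacent channels and rotate each pair: (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).
    x_even = x[..., ::2]
    x_odd = x[..., 1::2]
    return torch.stack((-x_odd, x_even), dim=-1).flatten(-2)

Either layout applies the same rotation, as long as the sin/cos tables are arranged consistently (repeat_interleave for the interleaved layout, torch.cat((freqs, freqs)) for the half layout), which is exactly what the hunk above does.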
@@ -205,7 +205,7 @@ class MolmoBlock(nn.Module):
         self._activation_checkpoint_fn = None
 
         # Dropout.
-        self.dropout = Dropout(config.residual_dropout, mask_p=config.response_residual_dropout)
+        self.dropout = Dropout(config.residual_dropout)
 
         # Layer norms.
         self.k_norm: Optional[LayerNormBase] = None
@@ -298,7 +298,6 @@ class MolmoBlock(nn.Module):
         k: torch.Tensor,
         v: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
         dropout_p: float = 0.0,
         response_dropout_p: float = 0.0,
         is_causal: bool = False,
@@ -341,7 +340,6 @@ class MolmoBlock(nn.Module):
         v: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -394,7 +392,6 @@ class MolmoBlock(nn.Module):
             k,
             v,
             attn_mask=attention_bias,
-            drop_mask=drop_mask,
             dropout_p=0.0 if not self.training else self.config.attention_dropout,
             response_dropout_p=0.0 if not self.training else self.config.response_attention_dropout,
             is_causal=attention_bias is None,
@@ -411,7 +408,6 @@ class MolmoBlock(nn.Module):
         x: torch.Tensor,
         attention_bias: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -419,183 +415,7 @@ class MolmoBlock(nn.Module):
 
     @classmethod
     def build(cls, layer_id: int, config: MolmoConfig, cache: BufferCache):
-        if config.block_type == "sequential":
-            return MolmoSequentialBlock(layer_id, config, cache)
-        elif config.block_type == "llama":
-            return OLMoLlamaBlock(layer_id, config, cache)
-        else:
-            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
-
-
-class OLMoLlamaBlock(MolmoBlock):
-    """
-    This is a transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
-    (plus another skip connection). This block is similar to `MolmoSequentialBlock`
-    but some operations have slightly different implementations to imitate the
-    behavior of Llama.
-    """
-
-    def __init__(self, layer_id: int, config: MolmoConfig, cache: BufferCache):
-        super().__init__(layer_id, config, cache)
-        # Layer norms.
-        self.attn_norm = LayerNorm.build(config)
-        self.ff_norm = LayerNorm.build(config)
-        self.__cache = cache
-
-        # Attention input projection. Projects x -> (q, k, v)
-        q_proj_out_dim = config.d_model
-        k_proj_out_dim = config.effective_n_kv_heads * (config.d_model // config.n_heads)
-        v_proj_out_dim = config.effective_n_kv_heads * (config.d_model // config.n_heads)
-
-        self.q_proj = nn.Linear(
-            config.d_model, q_proj_out_dim, bias=config.qkv_bias, device=config.init_device
-        )
-        self.k_proj = nn.Linear(
-            config.d_model, k_proj_out_dim, bias=config.qkv_bias, device=config.init_device
-        )
-        self.v_proj = nn.Linear(
-            config.d_model, v_proj_out_dim, bias=config.qkv_bias, device=config.init_device
-        )
-
-        # Feed-forward input projection.
-        self.ff_proj1 = nn.Linear(
-            config.d_model, self.hidden_size // 2, bias=False, device=config.init_device
-        )
-        self.ff_proj2 = nn.Linear(
-            config.d_model, self.hidden_size // 2, bias=False, device=config.init_device
-        )
-        if self.config.norm_after:
-            raise NotImplementedError()
-
-    def reset_parameters(self):
-        super().reset_parameters()
-        self.attn_norm.reset_parameters()
-        self.ff_norm.reset_parameters()
-        # NOTE: the standard deviation for these weights does not depend on the layer.
-        init_weights(self.config, self.q_proj, d=self.config.d_model, layer_id=None)
-        init_weights(self.config, self.k_proj, d=self.config.d_model, layer_id=None)
-        init_weights(self.config, self.v_proj, d=self.config.d_model, layer_id=None)
-        init_weights(self.config, self.ff_proj1, d=self.config.d_model, layer_id=None)
-        init_weights(self.config, self.ff_proj2, d=self.config.d_model, layer_id=None)
-
-    def _scaled_dot_product_attention(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
-        dropout_p: float = 0.0,
-        response_dropout_p: float = 0.0,
-        is_causal: bool = False,
-    ) -> torch.Tensor:
-        # For GQA
-        assert k.size(1) == v.size(1)
-        num_kv_heads = k.size(1)
-        num_q_heads = q.size(1)
-        if num_q_heads != num_kv_heads:
-            assert num_q_heads % num_kv_heads == 0
-            k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
-            v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
-
-        og_dtype = q.dtype
-        k = k.to(q.device)
-        v = v.to(q.device)
-        if attn_mask is not None:
-            attn_mask = attn_mask.to(q.device)
-
-        assert response_dropout_p == 0.0, "Response dropout is not supported in Llama."
-
-        if self.config.float32_attention:
-            q, k = q.to(torch.float), k.to(torch.float)
-
-        if self.config.attention_type == "direct":
-            attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
-
-            if is_causal:
-                assert attn_mask is None
-
-                query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None
-                attn_bias = get_causal_attention_bias(self.__cache, key_len, q.device)[:, :, :query_len, :key_len]
-            elif attn_mask is not None:
-                attn_bias = attn_mask
-            else:
-                attn_bias = torch.zeros_like(attn_weights)
-
-            attn_weights += attn_bias
-
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
-            attn_weights = nn.functional.dropout(attn_weights, p=dropout_p, training=self.training).to(v.dtype)
-
-            att = torch.matmul(attn_weights, v)
-        elif self.config.attention_type == "sdpa":
-            att = F.scaled_dot_product_attention(
-                q,
-                k,
-                v,
-                attn_mask=attn_mask,
-                dropout_p=dropout_p,
-                is_causal=is_causal,
-            )
-        else:
-            raise NotImplementedError(self.config.attention_type)
-        att = att.to(og_dtype)
-        return att
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        attention_bias: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-        use_cache: bool = False,
-    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
-        # Get query, key, value projections.
-        # shape:
-        #  - for regular attn q, k, v: (batch_size, seq_len, d_model)
-        #  - for multi-query attn q: (batch_size, seq_len, d_model)
-        #                      k, v: (batch_size, seq_len, d_model // n_heads)
-        x_normed = self.attn_norm(x)
-        q = self.q_proj(x_normed)
-        k = self.k_proj(x_normed)
-        v = self.v_proj(x_normed)
-
-        if self.config.clip_qkv is not None:
-            q.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
-            k.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
-            v.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
-
-        # Get attention scores.
-        if self._activation_checkpoint_fn is not None:
-            att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, position_ids=position_ids, drop_mask=drop_mask, layer_past=layer_past, use_cache=use_cache
-            )
-        else:
-            att, cache = self.attention(q, k, v, attention_bias, position_ids=position_ids, drop_mask=drop_mask, layer_past=layer_past, use_cache=use_cache)
-
-        # Add attention scores.
-        # shape: (B, T, C)
-        x = x + self.dropout(att, drop_mask=drop_mask)
-
-        # Add feed-forward projection.
-        # shape: (batch_size, seq_len, d_model)
-        og_x = x
-        if self._activation_checkpoint_fn is not None:
-            x = self._activation_checkpoint_fn(self.ff_norm, x)  # type: ignore
-        else:
-            x = self.ff_norm(x)
-        x1 = self.ff_proj1(x)
-        x2 = self.ff_proj2(x)
-        if self._activation_checkpoint_fn is not None:
-            x = self._activation_checkpoint_fn(self.act, x1, x2)  # type: ignore
-        else:
-            x = self.act(x1, x2)
-        x = self.ff_out(x)
-        x = self.dropout(x, drop_mask=drop_mask)
-        x = og_x + x
-
-        return x, cache
+        return MolmoSequentialBlock(layer_id, config, cache)
 
 
 class MolmoSequentialBlock(MolmoBlock):
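For reference, the removed OLMoLlamaBlock handled grouped-query attention by expanding the key/value heads with repeat_interleave so every query head gets a matching key/value head. A standalone sketch of that expansion, with made-up shapes (illustrative only):

import torch

q = torch.randn(1, 8, 16, 64)   # (batch, n_q_heads, seq_len, head_dim)
k = torch.randn(1, 4, 16, 64)   # (batch, n_kv_heads, seq_len, head_dim)
v = torch.randn(1, 4, 16, 64)

n_rep = q.size(1) // k.size(1)  # how many query heads share each kv head
# Repeat each kv head n_rep times along the head dimension.
k = k.repeat_interleave(n_rep, dim=1, output_size=q.size(1))
v = v.repeat_interleave(n_rep, dim=1, output_size=q.size(1))
att = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)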
@@ -644,7 +464,6 @@ class MolmoSequentialBlock(MolmoBlock):
         x: torch.Tensor,
         attention_bias: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        drop_mask: Optional[torch.Tensor] = None,
         layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         use_cache: bool = False,
     ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
@@ -673,10 +492,10 @@ class MolmoSequentialBlock(MolmoBlock):
         # Get attention scores.
         if self._activation_checkpoint_fn is not None:
             att, cache = self._activation_checkpoint_fn(  # type: ignore
-                self.attention, q, k, v, attention_bias, position_ids=position_ids, drop_mask=drop_mask, layer_past=layer_past, use_cache=use_cache
+                self.attention, q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache
             )
         else:
-            att, cache = self.attention(q, k, v, attention_bias, position_ids=position_ids, drop_mask=drop_mask, layer_past=layer_past, use_cache=use_cache)
+            att, cache = self.attention(q, k, v, attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
 
         if self.config.norm_after:
             if self._activation_checkpoint_fn is not None:
@@ -686,7 +505,7 @@ class MolmoSequentialBlock(MolmoBlock):
 
         # Add attention scores.
         # shape: (B, T, C)
-        x = x + self.dropout(att, drop_mask=drop_mask)
+        x = x + self.dropout(att)
 
         # Add feed-forward projection.
         # shape: (batch_size, seq_len, d_model)
@@ -711,7 +530,7 @@ class MolmoSequentialBlock(MolmoBlock):
         else:
             x = self.ff_norm(x)
 
-        x = self.dropout(x, drop_mask=drop_mask)
+        x = self.dropout(x)
         x = og_x + x
 
         return x, cache
@@ -757,27 +576,14 @@ class Dropout(nn.Dropout):
         self.mask_p = mask_p
         self.broadcast_dims = broadcast_dims
 
-    def forward(self, input: torch.Tensor, drop_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         """
         :param input: A tensor of shape `(batch_size, seq_len, embed_dim)`
-        :param drop_mask: A tensor of shape `(batch_size, seq_len)` with values of zero or one.
         """
         if self.p == 0.0 and (self.mask_p is None or self.mask_p == 0.0):
             return input
         else:
-            if self.mask_p > 0. and self.training:
-                assert drop_mask is not None
-                drop_mask = drop_mask.to(input.dtype)
-                keep_prob = 1.0 - self.p
-                keep_prob2 = 1.0 - self.mask_p
-                keep_prob = drop_mask * keep_prob2 + (1 - drop_mask) * keep_prob
-                keep_prob = keep_prob.unsqueeze(-1)
-                dropout_shape = list(input.shape)
-                keep_prob = keep_prob.broadcast_to(dropout_shape)
-                multiplier = input.new_empty(dropout_shape).bernoulli_(keep_prob)
-                multiplier.div_(keep_prob)
-                return input * multiplier
-            elif self.p > 0. and len(self.broadcast_dims) > 0 and self.training:
+            if self.p > 0. and len(self.broadcast_dims) > 0 and self.training:
                 keep_prob = 1.0 - self.p
                 dropout_shape = list(input.shape)
                 for dim in self.broadcast_dims:
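The surviving branch is ordinary inverted dropout, optionally sharing one Bernoulli mask across the configured broadcast_dims. A self-contained sketch of that behavior (illustrative function, not the class itself):

import torch

def broadcast_dropout(x: torch.Tensor, p: float, broadcast_dims=(0,)) -> torch.Tensor:
    # Inverted dropout: sample a keep mask, rescale by 1/keep_prob so the
    # expected value of the output matches the input.
    keep_prob = 1.0 - p
    shape = list(x.shape)
    for dim in broadcast_dims:
        shape[dim] = 1  # the same mask is reused along this dimension
    mask = x.new_empty(shape).bernoulli_(keep_prob).div_(keep_prob)
    return x * mask

x = torch.randn(4, 16, 32)
y = broadcast_dropout(x, p=0.1, broadcast_dims=(1,))  # one mask shared across the sequence dimension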
@@ -792,7 +598,6 @@ class Dropout(nn.Dropout):
 
 @dataclass
 class VisionBackboneConfig:
-    image_model_type: str = "openai"
     image_default_input_size: Tuple[int, int] = (336, 336)
     image_patch_size: int = 14
     image_pos_patch_size: int = 14
@@ -832,17 +637,12 @@ class FullMolmoConfig:
     mlp_ratio: int = 4
     mlp_hidden_size: Optional[int] = None
     activation_type: str = "swiglu"
-    block_type: str = "sequential"
     block_group_size: int = 1
-    alibi: bool = False
-    alibi_bias_max: float = 8.0
-    rope: bool = False
+    rope: bool = True
     rope_full_precision: bool = True
     rope_theta: float = 10000.
-    rope_impl: str = "cockatoo"
+    rope_impl: str = "interleave"
     vision_backbone: Optional[VisionBackboneConfig] = None
-    vit_load_path: Optional[str] = None
-    llm_load_path: Optional[str] = None
     attention_type: str = "sdpa"
     float32_attention: bool = True
     attention_dropout: float = 0.1
@@ -850,7 +650,6 @@ class FullMolmoConfig:
     multi_query_attention: Optional[bool] = None
     attention_layer_norm: bool = False
     residual_dropout: float = 0.1
-    response_residual_dropout: float = 0.0
     embedding_dropout: float = 0.1
     layer_norm_type: str = "default"
     layer_norm_with_affine: bool = True
@@ -872,10 +671,6 @@ class FullMolmoConfig:
     init_cutoff_factor: Optional[float] = None
     norm_after: bool = False
     precision: Optional[str] = None
-    max_crops: int = 12
-    crop_mode: str = "patchify-v2-and-resize-c2"
-    do_random_scale: bool = True
-    use_col_tokens: bool = True
     image_padding_embed: Optional[str] = None
     vit_layers: Tuple = (-1,)
     image_pooling_h: int = 2
@@ -883,12 +678,9 @@ class FullMolmoConfig:
     image_pooling_2d: str = "attention"
     image_projector: str = "mlp"
     image_feature_dropout: float = 0.0
-    use_cls_feature: bool = False
     initializer_range: float = 0.02
-    pad_tokenizer: bool = False
     normalize_input_embeds: bool = False
     use_position_ids: bool = True
-    query_pre_attn_scalar: int = 224
 
     @property
     def effective_n_kv_heads(self) -> int:
@@ -1112,7 +904,7 @@ class VisionTransformer(nn.Module):
         if patch_num is None:
             patch_num = self.config.vision_backbone.image_num_patch
         B, N, D = x.shape
-
+
         x = self.patch_embedding(x)
 
         # class embeddings and positional embeddings
@@ -1526,15 +1318,6 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
 
         self.num_prefix_tokens = self.image_vit.num_prefix_tokens
         assert self.num_prefix_tokens in {0, 1}, "Only 0 or 1 prefix tokens are supported"
-        if config.use_cls_feature:
-            assert self.num_prefix_tokens > 0, "The model does not have a CLS token"
-            nlayers = 1 if config.vit_layers is None else len(config.vit_layers)
-            self.cls_projector = nn.Linear(
-                nlayers * v_cfg.image_emb_dim,
-                self.input_dim,
-                bias=False,
-                device=config.init_device,
-            )
 
         self.pad_embed = None
         if config.image_padding_embed:
@@ -1551,8 +1334,6 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
     def reset_parameters(self):
         super().reset_parameters()
         self.image_vit.reset_parameters()
-        if self.config.use_cls_feature:
-            nn.init.xavier_uniform_(self.cls_projector.weight)
 
     def encode_image(self, images: torch.Tensor) -> torch.Tensor:
         """
@@ -1562,7 +1343,7 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
         v_cfg = self.config.vision_backbone
         B, T, N, D = images.shape
 
-        mask = torch.all(images.view(B * T, N, D) != -1, dim=(1, 2), keepdim=True)
+        mask = ~torch.all(images.view(B * T, N, D) == -1, dim=(1, 2), keepdim=True)
 
         # Output all hidden states
         # n_layers x (batch_num_crops, (1+)n_tokens, image_emb_dim)
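Note that the two mask expressions are not equivalent: the old mask treated a crop as valid only if no element equaled -1, while the new mask treats a crop as padding only when every element equals -1, so crops that merely contain some -1 values now stay valid. A small illustration with hypothetical values:

import torch

crop = torch.tensor([[-1.0, 0.5], [0.3, 0.7]])  # partially padded crop
print(torch.all(crop != -1))                    # tensor(False): old mask drops it
print(~torch.all(crop == -1))                   # tensor(True):  new mask keeps it

pad = torch.full((2, 2), -1.0)                  # fully padded crop
print(~torch.all(pad == -1))                    # tensor(False): still dropped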
@@ -1658,9 +1439,6 @@ class OLMoPretrainedVisionBackbone(OLMoVisionBackbone):
         else:
             image_features = self.image_projector(image_features)
 
-        if self.config.use_cls_feature:
-            raise NotImplementedError()
-
         # image_features: (batch_size, num_image, num_patch, d_model)
         # cls_embed: (batch_size, num_image, d_model)
         return image_features, cls_embed
@@ -1944,7 +1722,7 @@ class Molmo(nn.Module):
         else:
             self.transformer.update({"blocks": nn.ModuleList(blocks)})
 
-        if not (self.config.alibi or self.config.rope):
+        if not self.config.rope:
             self.transformer.update(
                 {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
             )
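With this change the learned absolute position table ("wpe") is created only when RoPE is disabled. For context, that path simply adds a position-embedding lookup to the token embeddings before the transformer blocks; a minimal sketch with made-up sizes (illustrative, not this model's forward pass):

import torch
import torch.nn as nn

d_model, max_sequence_length, vocab = 64, 128, 1000
wte = nn.Embedding(vocab, d_model)                # token embeddings
wpe = nn.Embedding(max_sequence_length, d_model)  # learned positional embeddings

input_ids = torch.randint(0, vocab, (2, 10))      # (batch, seq_len)
past_length = 0
pos = torch.arange(past_length, past_length + input_ids.size(1)).unsqueeze(0)
x = wte(input_ids) + wpe(pos)                     # positions added before the blocks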
@@ -2105,23 +1883,7 @@ class Molmo(nn.Module):
 
             x[batch_idx[valid], image_input_idx[valid]] += image_features[valid]
 
-        if self.config.use_cls_feature:
-            x = torch.cat([x[:, :1], cls_embed, x[:, 1:-num_image]], dim=1)
-
-            valid_images = torch.any(
-                (image_input_idx >= 0).view(batch_size, num_image, num_patch), dim=-1
-            )
-            valid_images = valid_images.to(attention_mask.dtype)
-            attention_mask = torch.cat(
-                [attention_mask[:, :1], valid_images, attention_mask[:, 1:-num_image]],
-                dim=1,
-            )
-            position_ids = torch.clamp(
-                torch.cumsum(attention_mask, dim=-1) - 1,
-                min=0,
-            ).broadcast_to((batch_size, attention_mask.shape[-1]))
-
-        if not (self.config.alibi or self.config.rope):
+        if not self.config.rope:
             # Get positional embeddings.
             # shape: (1, seq_len)
             pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
@@ -2151,17 +1913,12 @@ class Molmo(nn.Module):
         if (
             attention_bias is not None
             or attention_mask is not None
-            or self.config.alibi
            # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
            # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
            # scores correctly.
            or past_key_values is not None
         ):
-            if attention_bias is None and self.config.alibi:
-                attention_bias = get_causal_attention_bias(
-                    self.__cache, past_length + seq_len, x.device
-                ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
-            elif attention_bias is None:
+            if attention_bias is None:
                 attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
             elif attention_bias.dtype in (torch.int8, torch.bool):
                 attention_bias = attention_bias.to(dtype=torch.float)
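A causal attention bias of this kind is an additive mask: zero on and below the diagonal, -inf above it, broadcast over batch and heads. A rough standalone sketch of building such a bias (hypothetical helper, not the file's get_causal_attention_bias):

import torch

def make_causal_attention_bias(seq_len: int, device: torch.device) -> torch.Tensor:
    # (1, 1, seq_len, seq_len) additive bias: 0 where attention is allowed,
    # -inf where a query would attend to a future key.
    bias = torch.zeros(seq_len, seq_len, device=device)
    upper = torch.triu(torch.ones_like(bias, dtype=torch.bool), diagonal=1)
    bias.masked_fill_(upper, float("-inf"))
    return bias.view(1, 1, seq_len, seq_len)

bias = make_causal_attention_bias(4, torch.device("cpu"))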
@@ -2196,7 +1953,7 @@ class Molmo(nn.Module):
                 all_hidden_states.append(x)
 
                 layer_past = None if past_key_values is None else past_key_values[block_idx]
-                x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layer_past=layer_past, use_cache=use_cache)
+                x, cache = block(x, attention_bias=attention_bias, position_ids=position_ids, layer_past=layer_past, use_cache=use_cache)
 
                 if attn_key_values is not None:
                     assert cache is not None
@@ -2215,19 +1972,12 @@ class Molmo(nn.Module):
                     ]
                 )
                 x, cache = block_group(
-                    x, attention_bias=attention_bias, position_ids=position_ids, drop_mask=response_mask, layers_past=layers_past, use_cache=use_cache
+                    x, attention_bias=attention_bias, position_ids=position_ids, layers_past=layers_past, use_cache=use_cache
                 )
                 if attn_key_values is not None:
                     assert cache is not None
                     attn_key_values.extend(cache)
 
-        if images is not None and self.config.use_cls_feature:
-            assert num_image is not None
-            x = torch.cat(
-                [x[:, :1], x[:, num_image+1:], torch.zeros_like(x[:, :num_image])],
-                dim=1,
-            )
-
         if last_logits_only:
             # shape: (batch_size, 1, d_model)
             if append_last_valid_logits is not None:
@@ -2271,9 +2021,9 @@ class MolmoForCausalLM(PreTrainedModel):
 
         if not model:
             full_config = FullMolmoConfig(
-                attention_layer_norm=config.attention_layer_norm,
                 image_padding_embed="pad_and_partial_pad",
                 image_pooling_2d="attention-meanq",
+                attention_layer_norm=config.attention_layer_norm,
                 rope_impl="llama",
                 vocab_size=config.vocab_size,
                 max_sequence_length=config.max_position_embeddings,
@@ -2282,7 +2032,6 @@ class MolmoForCausalLM(PreTrainedModel):
                 embedding_size=config.embedding_size,
                 attention_type="sdpa",
                 embedding_dropout=0,
-                response_residual_dropout=0,
                 attention_dropout=0,
                 residual_dropout=0,
                 rope=True,
@@ -2297,10 +2046,8 @@ class MolmoForCausalLM(PreTrainedModel):
                 rope_theta=config.rope_theta,
                 layer_norm_eps=config.layer_norm_eps,
                 layer_norm_type=config.layer_norm_type,
-                pad_tokenizer=True,
                 vit_layers=[-2, -9],
                 vision_backbone=VisionBackboneConfig(
-                    image_model_type="openai",
                     image_default_input_size=(336, 336),
                     image_patch_size=14,
                     image_pos_patch_size=14,
 