yairschiff commited on
Commit
e504dcb
·
verified ·
1 Parent(s): 6e532f5

Ensure weights are tied for BiMamba (if applicable) when loaded from_pretrained

Browse files
Files changed (1) hide show
  1. modeling_caduceus.py +31 -1
modeling_caduceus.py CHANGED
@@ -360,6 +360,24 @@ class Caduceus(CaduceusPreTrainedModel):
360
  factory_kwargs = {"device": device, "dtype": dtype}
361
  self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  def forward(
364
  self,
365
  input_ids: torch.LongTensor = None,
@@ -431,8 +449,12 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
431
  raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
432
  self.lm_head = new_embeddings
433
 
 
 
 
434
  def tie_weights(self):
435
  """Tie weights, accounting for RCPS."""
 
436
  if self.config.rcps:
437
  self.lm_head.set_weight(self.get_input_embeddings().weight)
438
  else:
@@ -445,7 +467,7 @@ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
445
  def set_decoder(self, decoder):
446
  """Set decoder (backbone) for the model."""
447
  self.caduceus = decoder
448
-
449
  def forward(
450
  self,
451
  input_ids: torch.LongTensor = None,
@@ -536,6 +558,13 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
536
  if self.pooling_strategy == "first": # Use embedding of first token in the sequence
537
  return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]
538
 
 
 
 
 
 
 
 
539
  def forward(
540
  self,
541
  input_ids: torch.LongTensor = None,
@@ -543,6 +572,7 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
543
  labels: Optional[torch.LongTensor] = None,
544
  output_hidden_states: Optional[bool] = None,
545
  return_dict: Optional[bool] = None,
 
546
  ) -> Union[Tuple, SequenceClassifierOutput]:
547
  r"""
548
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
360
  factory_kwargs = {"device": device, "dtype": dtype}
361
  self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
362
 
363
+ def maybe_weight_tie_mamba(self):
364
+ if getattr(self.config, 'bidirectional', False) and getattr(self.config, 'bidirectional_weight_tie', False):
365
+ if getattr(self.config, 'rcps', False):
366
+ for layer in self.backbone.layers:
367
+ layer.mixer.submodule.mamba_rev.in_proj.weight = layer.mixer.submodule.mamba_fwd.in_proj.weight
368
+ layer.mixer.submodule.mamba_rev.in_proj.bias = layer.mixer.submodule.mamba_fwd.in_proj.bias
369
+ layer.mixer.submodule.mamba_rev.out_proj.weight = layer.mixer.submodule.mamba_fwd.out_proj.weight
370
+ layer.mixer.submodule.mamba_rev.out_proj.bias = layer.mixer.submodule.mamba_fwd.out_proj.bias
371
+ else:
372
+ for layer in self.backbone.layers:
373
+ layer.mixer.mamba_rev.in_proj.weight = layer.mixer.mamba_fwd.in_proj.weight
374
+ layer.mixer.mamba_rev.in_proj.bias = layer.mixer.mamba_fwd.in_proj.bias
375
+ layer.mixer.mamba_rev.out_proj.weight = layer.mixer.mamba_fwd.out_proj.weight
376
+ layer.mixer.mamba_rev.out_proj.bias = layer.mixer.mamba_fwd.out_proj.bias
377
+
378
+ def tie_weights(self):
379
+ self.maybe_weight_tie_mamba()
380
+
381
  def forward(
382
  self,
383
  input_ids: torch.LongTensor = None,
 
449
  raise NotImplementedError("Setting output embeddings for RCPS LM is not supported.")
450
  self.lm_head = new_embeddings
451
 
452
+ def maybe_weight_tie_mamba(self):
453
+ self.caduceus.maybe_weight_tie_mamba()
454
+
455
  def tie_weights(self):
456
  """Tie weights, accounting for RCPS."""
457
+ self.maybe_weight_tie_mamba()
458
  if self.config.rcps:
459
  self.lm_head.set_weight(self.get_input_embeddings().weight)
460
  else:
 
467
  def set_decoder(self, decoder):
468
  """Set decoder (backbone) for the model."""
469
  self.caduceus = decoder
470
+
471
  def forward(
472
  self,
473
  input_ids: torch.LongTensor = None,
 
558
  if self.pooling_strategy == "first": # Use embedding of first token in the sequence
559
  return hidden_states.moveaxis(hidden_states, sequence_length_dim, 0)[0, ...]
560
 
561
+ def maybe_weight_tie_mamba(self):
562
+ self.caduceus.maybe_weight_tie_mamba()
563
+
564
+ def tie_weights(self):
565
+ self.maybe_weight_tie_mamba()
566
+ super().tie_weights()
567
+
568
  def forward(
569
  self,
570
  input_ids: torch.LongTensor = None,
 
572
  labels: Optional[torch.LongTensor] = None,
573
  output_hidden_states: Optional[bool] = None,
574
  return_dict: Optional[bool] = None,
575
+ **kwargs,
576
  ) -> Union[Tuple, SequenceClassifierOutput]:
577
  r"""
578
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):