yairschiff committed on
Commit 9ae8522 · verified · 1 Parent(s): 3651f60

Enable mambav2 compat

Files changed (1)
  1. modeling_caduceus.py +33 -21
modeling_caduceus.py CHANGED
@@ -2,21 +2,29 @@
 
 """
 
+import inspect
 import math
 from functools import partial
 from typing import Optional, Tuple, Union
 
 import torch
-from mamba_ssm.modules.mamba_simple import Mamba, Block
+from mamba_ssm.modules.mamba_simple import Mamba
+try:
+    from mamba_ssm.modules.mamba_simple import Block  # Legacy mambav1 file structure
+except ImportError:
+    from mamba_ssm.modules.block import Block  # mambav2 file structure
 from torch import nn
 from torch.nn import functional as F
 from transformers import PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput
 
 try:
-    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
+    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn  # Legacy mambav1 file structure
 except ImportError:
-    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+    try:
+        from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn  # mambav2 file structure
+    except ImportError:
+        RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
 
 from .configuration_caduceus import CaduceusConfig
 from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
@@ -54,13 +62,24 @@ def create_block(
         nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
     )
     block_cls = RCPSMambaBlock if rcps else Block
-    block = block_cls(
-        d_model,
-        mixer_cls,
-        norm_cls=norm_cls,
-        fused_add_norm=fused_add_norm,
-        residual_in_fp32=residual_in_fp32,
-    )
+    # mambav2 compatibility
+    if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
+        block = block_cls(
+            d_model,
+            mixer_cls,
+            mlp_cls=nn.Identity,
+            norm_cls=norm_cls,
+            fused_add_norm=fused_add_norm,
+            residual_in_fp32=residual_in_fp32,
+        )
+    else:
+        block = block_cls(
+            d_model,
+            mixer_cls,
+            norm_cls=norm_cls,
+            fused_add_norm=fused_add_norm,
+            residual_in_fp32=residual_in_fp32,
+        )
     block.layer_idx = layer_idx
     return block
 
@@ -264,15 +283,15 @@ def cross_entropy(logits, y, ignore_index=-100):
     return F.cross_entropy(logits, y, ignore_index=ignore_index)
 
 
-def weighted_cross_entropy(logits, y, loss_weight, ignore_index=-100):
+def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
     """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
     logits = logits.view(-1, logits.shape[-1])
     y = y.view(-1)
     ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
-    loss_weight = loss_weight.view(-1)
-    loss_weight[y == ignore_index] = 0.0
+    loss_weights = loss_weights.view(-1)
+    loss_weights[y == ignore_index] = 0.0
     # TODO: Follows GPN implementation, but should we remove weight normalization?
-    return (ce * (loss_weight / loss_weight.sum())).sum()
+    return (ce * (loss_weights / loss_weights.sum())).sum()
 
 
 class CaduceusPreTrainedModel(PreTrainedModel):
@@ -497,12 +516,6 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
 
         # Initialize weights and apply final processing
         self.post_init()
-        self.init_scorer()
-
-    def init_scorer(self, initializer_range=0.02):
-        initializer_range = self.config.initializer_cfg.get("initializer_range", initializer_range) \
-            if self.config.initializer_cfg is not None else initializer_range
-        self.score.weight.data.normal_(std=initializer_range)
 
     def get_input_embeddings(self):
         return self.caduceus.backbone.embeddings.word_embeddings
@@ -530,7 +543,6 @@ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
         labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
-        **kwargs,
     ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
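
The heart of the change is the compatibility shim in the first two hunks: the imports fall back from the mamba v1 module layout to the v2 layout, and `create_block` inspects the `Block` constructor before passing the `mlp_cls` argument that only exists in mamba v2. Below is a minimal, self-contained sketch of that signature-dispatch pattern; the `LegacyBlock` and `NewBlock` classes are hypothetical stand-ins, not part of `mamba_ssm`.

```python
import inspect

from torch import nn


class LegacyBlock:
    """Stand-in for a mamba v1-style Block: no mlp_cls parameter."""
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm):
        self.dim, self.mixer_cls, self.norm_cls = dim, mixer_cls, norm_cls


class NewBlock:
    """Stand-in for a mamba v2-style Block: accepts an mlp_cls parameter."""
    def __init__(self, dim, mixer_cls, mlp_cls=nn.Identity, norm_cls=nn.LayerNorm):
        self.dim, self.mixer_cls, self.mlp_cls, self.norm_cls = dim, mixer_cls, mlp_cls, norm_cls


def build_block(block_cls, dim, mixer_cls):
    # Dispatch on the constructor signature instead of the installed package
    # version: only pass mlp_cls when the class actually accepts it, and use
    # nn.Identity so no extra MLP is inserted into the block.
    kwargs = {"norm_cls": nn.LayerNorm}
    if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
        kwargs["mlp_cls"] = nn.Identity
    return block_cls(dim, mixer_cls, **kwargs)


print(type(build_block(LegacyBlock, 256, nn.Linear)).__name__)  # LegacyBlock
print(type(build_block(NewBlock, 256, nn.Linear)).__name__)     # NewBlock
```

Dispatching on the signature rather than on a version string keeps a single code path working against either release of the library without pinning a dependency.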