emarro committed
Commit 0215062
1 Parent(s): 10e7eae

Upload AxialCaduceusForMaskedLM

Files changed (4)
  1. config.json +4 -0
  2. configuration_caduceus.py +174 -0
  3. modeling_caduceus.py +1645 -0
  4. modeling_rcps.py +243 -0
config.json CHANGED
@@ -4,6 +4,10 @@
   "architectures": [
     "AxialCaduceusForMaskedLM"
   ],
+  "auto_map": {
+    "AutoConfig": "configuration_caduceus.AxialCaduceusConfig",
+    "AutoModelForMaskedLM": "modeling_caduceus.AxialCaduceusForMaskedLM"
+  },
   "bidirectional": "true,",
   "bidirectional_strategy": "add",
   "bidirectional_weight_tie": true,
configuration_caduceus.py ADDED
@@ -0,0 +1,174 @@
1
+ """Caduceus config for Hugging Face.
2
+
3
+ """
4
+
5
+ from typing import Optional, Union
6
+
7
+ from transformers import PretrainedConfig
8
+
9
+
10
+ class CaduceusConfig(PretrainedConfig):
11
+ """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""
12
+
13
+ model_type = "caduceus"
14
+
15
+ def __init__(
16
+ self,
17
+ # From original MambaConfig
18
+ d_model: int = 2560,
19
+ d_intermediate: int = 0,
20
+ use_mamba2: bool = False,
21
+ n_layer: int = 64,
22
+ vocab_size: int = 50277,
23
+ ssm_cfg: Optional[dict] = None,
24
+ rms_norm: bool = True,
25
+ residual_in_fp32: bool = True,
26
+ fused_add_norm: bool = True,
27
+ pad_vocab_size_multiple: int = 8,
28
+ # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
29
+ norm_epsilon: float = 1e-5,
30
+ # Used in init_weights
31
+ initializer_cfg: Optional[dict] = None,
32
+ # Caduceus-specific params
33
+ bidirectional: bool = True,
34
+ bidirectional_strategy: Union[str, None] = "add",
35
+ bidirectional_weight_tie: bool = True,
36
+ rcps: bool = False,
37
+ complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
38
+ pos_embeddings: Optional[str] = None,
39
+ row_first: Optional[bool] = True,
40
+ **kwargs,
41
+ ):
42
+ super().__init__(**kwargs)
43
+ self.d_model = d_model
44
+ self.d_intermediate = d_intermediate
45
+ self.use_mamba2 = use_mamba2
46
+ self.n_layer = n_layer
47
+ self.vocab_size = vocab_size
48
+ self.ssm_cfg = ssm_cfg
49
+ self.rms_norm = rms_norm
50
+ self.residual_in_fp32 = residual_in_fp32
51
+ self.fused_add_norm = fused_add_norm
52
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
53
+ self.norm_epsilon = norm_epsilon
54
+ self.initializer_cfg = initializer_cfg
55
+ self.bidirectional = bidirectional
56
+ self.bidirectional_strategy = bidirectional_strategy
57
+ self.bidirectional_weight_tie = bidirectional_weight_tie
58
+ self.rcps = rcps
59
+ self.complement_map = complement_map
60
+ self.pos_embeddings = pos_embeddings
61
+ self.row_first = row_first
62
+
63
+ class AxialCaduceusConfig(PretrainedConfig):
64
+ """Config that extends the original MambaConfig with params relevant to bi-directionality and RC equivariance."""
65
+
66
+ model_type = "axial_caduceus"
67
+
68
+ def __init__(
69
+ self,
70
+ # From original MambaConfig
71
+ d_model: int = 2560,
72
+ d_intermediate: int = 0,
73
+ use_mamba2: bool = False,
74
+ n_layer: int = 64,
75
+ vocab_size: int = 50277,
76
+ ssm_cfg: Optional[dict] = None,
77
+ rms_norm: bool = True,
78
+ residual_in_fp32: bool = True,
79
+ fused_add_norm: bool = True,
80
+ pad_vocab_size_multiple: int = 8,
81
+ # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
82
+ norm_epsilon: float = 1e-5,
83
+ # Used in init_weights
84
+ initializer_cfg: Optional[dict] = None,
85
+ # Caduceus-specific params
86
+ bidirectional: bool = True,
87
+ bidirectional_strategy: Union[str, None] = "add",
88
+ bidirectional_weight_tie: bool = True,
89
+ rcps: bool = False,
90
+ complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
91
+ pos_embeddings: Optional[str] = None,
92
+ row_first: Optional[bool] = True,
93
+ **kwargs,
94
+ ):
95
+ super().__init__(**kwargs)
96
+ self.d_model = d_model
97
+ self.d_intermediate = d_intermediate
98
+ self.use_mamba2 = use_mamba2
99
+ self.n_layer = n_layer
100
+ self.vocab_size = vocab_size
101
+ self.ssm_cfg = ssm_cfg
102
+ self.rms_norm = rms_norm
103
+ self.residual_in_fp32 = residual_in_fp32
104
+ self.fused_add_norm = fused_add_norm
105
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
106
+ self.norm_epsilon = norm_epsilon
107
+ self.initializer_cfg = initializer_cfg
108
+ self.bidirectional = bidirectional
109
+ self.bidirectional_strategy = bidirectional_strategy
110
+ self.bidirectional_weight_tie = bidirectional_weight_tie
111
+ self.rcps = rcps
112
+ self.complement_map = complement_map
113
+ self.pos_embeddings = pos_embeddings
114
+ self.row_first = row_first
115
+
116
+
117
+
118
+ class MixedCaduceusConfig(PretrainedConfig):
119
+ """Config that extends the original CaduceusConfig with params relevant to alternating between attention and caducues"""
120
+
121
+ model_type = "mixed_caduceus"
122
+
123
+ def __init__(
124
+ self,
125
+ # From original MambaConfig
126
+ d_model: int = 2560,
127
+ d_intermediate: int = 0,
128
+ use_mamba2: bool = False,
129
+ n_layer: int = 64,
130
+ vocab_size: int = 50277,
131
+ ssm_cfg: Optional[dict] = None,
132
+ rms_norm: bool = True,
133
+ residual_in_fp32: bool = True,
134
+ fused_add_norm: bool = True,
135
+ pad_vocab_size_multiple: int = 8,
136
+ # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
137
+ norm_epsilon: float = 1e-5,
138
+ # Used in init_weights
139
+ initializer_cfg: Optional[dict] = None,
140
+ # Caduceus-specific params
141
+ bidirectional: bool = True,
142
+ bidirectional_strategy: Union[str, None] = "add",
143
+ bidirectional_weight_tie: bool = True,
144
+ rcps: bool = False,
145
+ complement_map: Optional[dict] = None, # used for RCPSEmbedding / RCPSLMHead
146
+ # attention specific params
147
+ attn_d_model: int = 128,
148
+ attn_n_heads: int = 16,
149
+ attn_attn_dropout: float = 0.1,
150
+ attn_block_dropout: float = 0.1,
151
+ **kwargs,
152
+ ):
153
+ super().__init__(**kwargs)
154
+ self.d_model = d_model
155
+ self.d_intermediate = d_intermediate
156
+ self.use_mamba2 = use_mamba2
157
+ self.n_layer = n_layer
158
+ self.vocab_size = vocab_size
159
+ self.ssm_cfg = ssm_cfg
160
+ self.rms_norm = rms_norm
161
+ self.residual_in_fp32 = residual_in_fp32
162
+ self.fused_add_norm = fused_add_norm
163
+ self.pad_vocab_size_multiple = pad_vocab_size_multiple
164
+ self.norm_epsilon = norm_epsilon
165
+ self.initializer_cfg = initializer_cfg
166
+ self.bidirectional = bidirectional
167
+ self.bidirectional_strategy = bidirectional_strategy
168
+ self.bidirectional_weight_tie = bidirectional_weight_tie
169
+ self.rcps = rcps
170
+ self.complement_map = complement_map
171
+ self.attn_d_model = attn_d_model
172
+ self.attn_n_heads = attn_n_heads
173
+ self.attn_attn_dropout = attn_attn_dropout
174
+ self.attn_block_dropout = attn_block_dropout
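As a quick sanity check, the config classes above can be instantiated without the Mamba kernels, since they only depend on transformers. A minimal sketch; the flat import assumes the script is run from the repository directory:

    # Minimal sketch: build a small AxialCaduceusConfig and round-trip it.
    from configuration_caduceus import AxialCaduceusConfig

    cfg = AxialCaduceusConfig(
        d_model=256,
        n_layer=4,
        vocab_size=16,
        bidirectional=True,
        bidirectional_strategy="add",
        row_first=True,
    )
    print(cfg.model_type)            # "axial_caduceus"
    print(cfg.to_dict()["d_model"])  # 256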
modeling_caduceus.py ADDED
@@ -0,0 +1,1645 @@
1
+ """Caduceus model for Hugging Face.
2
+
3
+ """
4
+
5
+ import math
6
+ from functools import partial
7
+ from typing import Optional, Tuple, Union
8
+
9
+ import torch
10
+ #from mamba_ssm.modules.mamba_simple import Mamba, Block
11
+ #from mamba_ssm.modules import Block
12
+ from mamba_ssm import Mamba, Mamba2
13
+ from mamba_ssm.modules.block import Block
14
+ from mamba_ssm.modules.mlp import GatedMLP
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+ from torch.nn.parallel import parallel_apply
18
+ from transformers import PreTrainedModel
19
+ from transformers.modeling_outputs import (
20
+ BaseModelOutputWithNoAttention,
21
+ MaskedLMOutput,
22
+ SequenceClassifierOutput,
23
+ )
24
+
25
+ try:
26
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
27
+ except ImportError:
28
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
29
+
30
+ from .configuration_caduceus import CaduceusConfig, MixedCaduceusConfig, AxialCaduceusConfig
31
+ from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
32
+ #from .esm_repo.esm.axial_attention import RowSelfAttention
33
+ #from .esm_repo.esm.modules import NormalizedResidualBlock
34
+
35
+
36
+ def sinusoidal_encoding(positions: torch.Tensor, d_model: int, device=None, dtype=None):
37
+ """
38
+ from https://github.com/wzlxjtu/PositionalEncoding2D
39
+ :param d_model: dimension of the model
40
+ :param positions: Tensor of the input positions [B, L]
41
+ :return: [B, L, d_model] tensor of sinusoidal position encodings
42
+ """
43
+ factory_kwargs = {"device": device, "dtype": dtype}
44
+ if d_model % 2 != 0:
45
+ raise ValueError("Cannot use sin/cos positional encoding with "
46
+ "odd dim (got dim={:d})".format(d_model))
47
+ B, L = positions.size()
48
+ pe = torch.zeros(B, L, d_model, **factory_kwargs)  # [B, L, D]
49
+
50
+ # position = torch.arange(0, length).unsqueeze(1) #[L, 1]
51
+ position = positions.unsqueeze(-1) # [B,L,1]
52
+ div_term = torch.exp((torch.arange(0, d_model, 2, device=position.device, dtype=torch.float) *
53
+ -(math.log(10000.0) / d_model)))
54
+ pe[:, :, 0::2] = torch.sin(position.float() * div_term)
55
+ pe[:, :, 1::2] = torch.cos(position.float() * div_term)
56
+ pe = pe.to(**factory_kwargs)
57
+ return pe
58
+
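A small usage sketch for sinusoidal_encoding above (pure PyTorch, no Mamba kernels needed); it assumes the function defined above is in scope:

    import torch

    # Encode integer positions for a batch of 2 sequences of length 8 into d_model=16.
    positions = torch.arange(8).unsqueeze(0).repeat(2, 1)   # [B=2, L=8]
    pe = sinusoidal_encoding(positions, d_model=16)         # [2, 8, 16]
    assert pe.shape == (2, 8, 16)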
59
+ def create_block(
60
+ d_model,
61
+ ssm_cfg=None,
62
+ norm_epsilon=1e-5,
63
+ rms_norm=False,
64
+ residual_in_fp32=False,
65
+ fused_add_norm=False,
66
+ layer_idx=None,
67
+ bidirectional=True,
68
+ bidirectional_strategy="add",
69
+ bidirectional_weight_tie=True,
70
+ rcps=False,
71
+ device=None,
72
+ dtype=None,
73
+ ):
74
+ """Create Caduceus block.
75
+
76
+ Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
77
+ """
78
+ if ssm_cfg is None:
79
+ ssm_cfg = {}
80
+ factory_kwargs = {"device": device, "dtype": dtype}
81
+ bidirectional_kwargs = {
82
+ "bidirectional": bidirectional,
83
+ "bidirectional_strategy": bidirectional_strategy,
84
+ "bidirectional_weight_tie": bidirectional_weight_tie,
85
+ }
86
+ mixer_cls = partial(
87
+ BiMambaWrapper,
88
+ layer_idx=layer_idx,
89
+ **ssm_cfg,
90
+ **bidirectional_kwargs,
91
+ **factory_kwargs,
92
+ )
93
+ norm_cls = partial(
94
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
95
+ )
96
+ block_cls = RCPSMambaBlock if rcps else Block
97
+ d_intermediate = 0  # hard-coded here: no gated MLP in this (non-axial) block
98
+ if d_intermediate == 0:
99
+ mlp_cls = nn.Identity
100
+ else:
101
+ mlp_cls = partial(
102
+ GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
103
+ )
104
+ block = block_cls(
105
+ dim=d_model,
106
+ mixer_cls=mixer_cls,
107
+ mlp_cls=mlp_cls,
108
+ norm_cls=norm_cls,
109
+ fused_add_norm=fused_add_norm,
110
+ residual_in_fp32=residual_in_fp32,
111
+ )
112
+ block.layer_idx = layer_idx
113
+ return block
114
+
115
+
116
+ def create_axial_block(
117
+ d_model,
118
+ d_intermediate,
119
+ use_mamba2,
120
+ axis,
121
+ ssm_cfg=None,
122
+ norm_epsilon=1e-5,
123
+ rms_norm=False,
124
+ residual_in_fp32=False,
125
+ fused_add_norm=False,
126
+ layer_idx=None,
127
+ bidirectional=True,
128
+ bidirectional_strategy="add",
129
+ bidirectional_weight_tie=True,
130
+ rcps=False,
131
+ device=None,
132
+ dtype=None,
133
+ ):
134
+ """Create an axial Caduceus block composed of two AxialCaduceus blocks, one for row and one for columns.
135
+
136
+ Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
137
+ """
138
+ if ssm_cfg is None:
139
+ ssm_cfg = {}
140
+ factory_kwargs = {"device": device, "dtype": dtype}
141
+ bidirectional_kwargs = {
142
+ "bidirectional": bidirectional,
143
+ "bidirectional_strategy": bidirectional_strategy,
144
+ "bidirectional_weight_tie": bidirectional_weight_tie,
145
+ }
146
+ #mixer_cls = partial(
147
+ # Mamba2 if ssm_layer == "Mamba2" else Mamba,
148
+ # layer_idx=layer_idx,
149
+ # **ssm_cfg,
150
+ # **factory_kwargs
151
+ #)
152
+
153
+ mixer_cls = partial(
154
+ AxialBiMambaWrapper,
155
+ use_mamba2=use_mamba2,
156
+ axis=axis,
157
+ layer_idx=layer_idx,
158
+ **ssm_cfg,
159
+ **bidirectional_kwargs,
160
+ **factory_kwargs,
161
+ )
162
+ norm_cls = partial(
163
+ nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
164
+ )
165
+ block_cls = RCPSMambaBlock if rcps else Block
166
+ if d_intermediate == 0:
167
+ mlp_cls = nn.Identity
168
+ else:
169
+ mlp_cls = partial(
170
+ GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
171
+ )
172
+
173
+ block = block_cls(
174
+ dim=d_model,
175
+ mixer_cls=mixer_cls,
176
+ mlp_cls=mlp_cls,
177
+ norm_cls=norm_cls,
178
+ fused_add_norm=fused_add_norm,
179
+ residual_in_fp32=residual_in_fp32,
180
+ )
181
+ block.layer_idx = layer_idx
182
+ return block
183
+
184
+ def create_attention_block(
185
+ d_model: int,
186
+ n_heads: int,
187
+ attention_dropout: float,
188
+ block_dropout: float,
189
+ layer_idx=None,
190
+ device=None,
191
+ dtype=None,
192
+ ):
193
+ """Create an RowAttention block from MSATransformer."""
194
+ raise NotImplementedError()
195
+ # factory_kwargs = {"device": device, "dtype": dtype}
196
+ # layer_cls = RowSelfAttention(
197
+ # embed_dim=d_model, num_heads=n_heads, dropout=attention_dropout
198
+ # )
199
+ # block = NormalizedResidualBlock(
200
+ # layer=layer_cls, embedding_dim=d_model, dropout=block_dropout
201
+ # ) # Wraps attention with residual connection, layer norm, and drop out. NOTE: No mixer in this block
202
+ # block = block.to(device)
203
+ # block.layer_idx = layer_idx
204
+ # return block
205
+
206
+
207
+ class BiMambaWrapper(nn.Module):
208
+ """Thin wrapper around Mamba to support bi-directionality."""
209
+
210
+ def __init__(
211
+ self,
212
+ d_model: int,
213
+ bidirectional: bool = True,
214
+ bidirectional_strategy: Optional[str] = "add",
215
+ bidirectional_weight_tie: bool = True,
216
+ **mamba_kwargs,
217
+ ):
218
+ super().__init__()
219
+ if bidirectional and bidirectional_strategy is None:
220
+ bidirectional_strategy = "add" # Default strategy: `add`
221
+ if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
222
+ raise NotImplementedError(
223
+ f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!"
224
+ )
225
+ self.bidirectional = bidirectional
226
+ self.bidirectional_strategy = bidirectional_strategy
227
+ self.mamba_fwd = Mamba(d_model=d_model, **mamba_kwargs)
228
+ if bidirectional:
229
+ self.mamba_rev = Mamba(d_model=d_model, **mamba_kwargs)
230
+ if (
231
+ bidirectional_weight_tie
232
+ ): # Tie in and out projections (where most of param count lies)
233
+ self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
234
+ self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
235
+ self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
236
+ self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
237
+ else:
238
+ self.mamba_rev = None
239
+
240
+ def forward(self, hidden_states, inference_params=None):
241
+ """Bidirectional-enabled forward pass
242
+
243
+ hidden_states: (B, L, D)
244
+ Returns: same shape as hidden_states
245
+ """
246
+ out = self.mamba_fwd(hidden_states, inference_params=inference_params)
247
+ if self.bidirectional:
248
+ out_rev = self.mamba_rev(
249
+ hidden_states.flip(
250
+ dims=(1,)
251
+ ), # Flip along the sequence length dimension
252
+ inference_params=inference_params,
253
+ ).flip(dims=(1,)) # Flip back for combining with forward hidden states
254
+ if self.bidirectional_strategy == "add":
255
+ out = out + out_rev
256
+ elif self.bidirectional_strategy == "ew_multiply":
257
+ out = out * out_rev
258
+ else:
259
+ raise NotImplementedError(
260
+ f"`{self.bidirectional_strategy}` for bi-directionality not implemented!"
261
+ )
262
+ return out
263
+
264
+
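To make the flip-run-flip pattern in BiMambaWrapper.forward concrete, here is a toy sketch with a stand-in linear mixer in place of Mamba (so no CUDA kernels are needed); ToyBiWrapper is a hypothetical name, not part of this repo:

    import torch
    from torch import nn

    class ToyBiWrapper(nn.Module):
        """Stand-in for BiMambaWrapper: a forward pass plus a reversed pass, then combine."""
        def __init__(self, d_model, strategy="add"):
            super().__init__()
            self.fwd = nn.Linear(d_model, d_model)
            self.rev = nn.Linear(d_model, d_model)
            self.strategy = strategy

        def forward(self, x):  # x: (B, L, D)
            out = self.fwd(x)
            # Flip along L, run the reverse mixer, flip back before combining.
            out_rev = self.rev(x.flip(dims=(1,))).flip(dims=(1,))
            return out + out_rev if self.strategy == "add" else out * out_rev

    x = torch.randn(2, 5, 8)
    print(ToyBiWrapper(8)(x).shape)  # torch.Size([2, 5, 8])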
265
+ class AxialBiMambaWrapper(nn.Module):
266
+ """Thin wrapper around BiMamba to support running and aggregating over rows.
267
+ axis=1 for RowMamba, axis=2 for column Mamba
268
+ """
269
+
270
+ def __init__(
271
+ self,
272
+ d_model: int,
273
+ use_mamba2: bool,
274
+ bidirectional: bool = True,
275
+ bidirectional_strategy: Optional[str] = "add",
276
+ bidirectional_weight_tie: bool = True,
277
+ axis: int = 1,
278
+ **mamba_kwargs,
279
+ ):
280
+ super().__init__()
281
+ if bidirectional and bidirectional_strategy is None:
282
+ bidirectional_strategy = "add" # Default strategy: `add`
283
+ if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
284
+ raise NotImplementedError(
285
+ f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!"
286
+ )
287
+ self.bidirectional = bidirectional
288
+ self.bidirectional_strategy = bidirectional_strategy
289
+ self.mamba_fwd = Mamba2(d_model=d_model, **mamba_kwargs) if use_mamba2 else Mamba(d_model=d_model, **mamba_kwargs)
290
+ self.axis = axis
291
+ if bidirectional:
292
+ self.mamba_rev = Mamba2(d_model=d_model, **mamba_kwargs) if use_mamba2 else Mamba(d_model=d_model, **mamba_kwargs)
293
+ if (
294
+ bidirectional_weight_tie
295
+ ): # Tie in and out projections (where most of param count lies)
296
+ self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
297
+ self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
298
+ self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
299
+ self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
300
+ else:
301
+ self.mamba_rev = None
302
+
303
+ def forward(self, hidden_states, inference_params=None):
304
+ """Bidirectional-enabled forward pass
305
+
306
+ hidden_states: (B, R, C, D)
307
+ Returns: same shape as hidden_states
308
+ """
309
+ def apply_mamba(x):
310
+ out = self.mamba_fwd(x, inference_params=inference_params)
311
+ if self.bidirectional:
312
+ out_rev = self.mamba_rev(
313
+ x.flip(
314
+ dims=(1,)
315
+ ), # Flip along the sequence length dimension
316
+ inference_params=inference_params,
317
+ ).flip(dims=(1,)) # Flip back for combining with forward hidden states
318
+ if self.bidirectional_strategy == "add":
319
+ out = out + out_rev
320
+ elif self.bidirectional_strategy == "ew_multiply":
321
+ out = out * out_rev
322
+ else:
323
+ raise NotImplementedError(
324
+ f"`{self.bidirectional_strategy}` for bi-directionality not implemented!"
325
+ )
326
+ return out
327
+ batch, rows, columns, hidden_dim = hidden_states.size()
328
+ if self.axis == 1: # row mamba
329
+ hidden_states = hidden_states.permute(1, 0, 2, 3)
330
+ axis_len = rows
331
+ elif self.axis == 2:
332
+ hidden_states = hidden_states.permute(2, 0, 1, 3)
333
+ axis_len = columns
334
+ outs = []
335
+ ## parallel
336
+ #outs = parallel_apply([apply_mamba for _ in range(axis_len)], hidden_states.unbind(0))
337
+
338
+ ## reshape
339
+ outs = apply_mamba(hidden_states.reshape(axis_len * batch, -1, hidden_dim))
340
+ out = outs.reshape(axis_len, batch, -1, hidden_dim)
341
+
342
+
343
+ ### for loop
344
+ #for axis_idx in range(axis_len):
345
+ #tmp_hidden_states = hidden_states[axis_idx, ...]
346
+ #out = apply_mamba(tmp_hidden_states)
347
+ #outs.append(out)
348
+ #out = torch.stack(outs, dim=0)
349
+ if self.axis == 1: # row mamba
350
+ out = out.permute(1, 0, 2, 3)
351
+ elif self.axis == 2: # [C, B, R, D]
352
+ out = out.permute(1, 2, 0, 3)
353
+ return out
354
+
355
+
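A shape-only sketch of the reshape trick used in AxialBiMambaWrapper.forward above: the non-scanned axis is folded into the batch so an ordinary 1-D sequence mixer can run, then the (B, R, C, D) layout is restored. The identity stands in for the Mamba call:

    import torch

    B, R, C, D = 2, 3, 4, 8
    x = torch.randn(B, R, C, D)

    # axis == 2 ("column Mamba"): fold the C axis into the batch and scan length-R sequences.
    xc = x.permute(2, 0, 1, 3)                          # (C, B, R, D)
    out = xc.reshape(C * B, R, D)                       # (C*B, R, D) -- 1-D mixer would run here
    out = out.reshape(C, B, R, D).permute(1, 2, 0, 3)   # back to (B, R, C, D)
    assert out.shape == x.shape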
356
+ class CaduceusEmbeddings(nn.Module):
357
+ def __init__(
358
+ self,
359
+ config: CaduceusConfig,
360
+ device=None,
361
+ dtype=None,
362
+ ):
363
+ super().__init__()
364
+ factory_kwargs = {"device": device, "dtype": dtype}
365
+ if config.rcps:
366
+ self.word_embeddings = RCPSEmbedding(
367
+ config.vocab_size,
368
+ config.d_model,
369
+ config.complement_map,
370
+ **factory_kwargs,
371
+ )
372
+ else:
373
+ self.word_embeddings = nn.Embedding(
374
+ config.vocab_size, config.d_model, **factory_kwargs
375
+ )
376
+
377
+ def forward(self, input_ids):
378
+ """
379
+ input_ids: (batch, seqlen)
380
+ """
381
+ return self.word_embeddings(input_ids)
382
+
383
+
384
+ class CaduceusMixerModel(nn.Module):
385
+ def __init__(
386
+ self,
387
+ config: CaduceusConfig,
388
+ device=None,
389
+ dtype=None,
390
+ ) -> None:
391
+ super().__init__()
392
+ factory_kwargs = {"device": device, "dtype": dtype}
393
+
394
+ self.fused_add_norm = config.fused_add_norm
395
+ self.rcps = config.rcps
396
+ self.residual_in_fp32 = config.residual_in_fp32
397
+
398
+ self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
399
+
400
+ # Mamba changes the order of residual and layer norm:
401
+ # Instead of LN -> Attn / MLP -> Add, we do:
402
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
403
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
404
+ # This is for performance reasons: we can fuse add + layer_norm.
405
+ if config.fused_add_norm:
406
+ if layer_norm_fn is None or rms_norm_fn is None:
407
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
408
+
409
+ self.layers = nn.ModuleList(
410
+ [
411
+ create_block(
412
+ config.d_model,
413
+ ssm_cfg=config.ssm_cfg,
414
+ norm_epsilon=config.norm_epsilon,
415
+ rms_norm=config.rms_norm,
416
+ residual_in_fp32=config.residual_in_fp32,
417
+ fused_add_norm=config.fused_add_norm,
418
+ layer_idx=i,
419
+ bidirectional=config.bidirectional,
420
+ bidirectional_strategy=config.bidirectional_strategy,
421
+ bidirectional_weight_tie=config.bidirectional_weight_tie,
422
+ rcps=config.rcps,
423
+ **factory_kwargs,
424
+ )
425
+ for i in range(config.n_layer)
426
+ ]
427
+ )
428
+
429
+ norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
430
+ config.d_model, eps=config.norm_epsilon, **factory_kwargs
431
+ )
432
+ self.norm_f = (
433
+ norm_f
434
+ if (config.fused_add_norm or not config.rcps)
435
+ else RCPSAddNormWrapper(norm_f)
436
+ )
437
+
438
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
439
+ """Mixer forward."""
440
+ all_hidden_states = []
441
+ if inputs_embeds is not None:
442
+ hidden_states = inputs_embeds
443
+ else:
444
+ hidden_states = self.embeddings(input_ids)
445
+
446
+ residual = None
447
+ for layer in self.layers:
448
+ if output_hidden_states:
449
+ all_hidden_states.append(hidden_states)
450
+ # TODO: Add support for gradient checkpointing
451
+ hidden_states, residual = layer(
452
+ hidden_states, residual, inference_params=None
453
+ )
454
+
455
+ if not self.fused_add_norm:
456
+ if self.rcps:
457
+ # Set prenorm=False here since we don't need the residual
458
+ hidden_states = self.norm_f(
459
+ hidden_states, residual=residual, prenorm=False
460
+ )
461
+ else:
462
+ residual = (
463
+ (hidden_states + residual)
464
+ if residual is not None
465
+ else hidden_states
466
+ )
467
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
468
+ else:
469
+ fused_add_norm_fn = (
470
+ rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
471
+ )
472
+ if self.rcps:
473
+ # Set prenorm=False here since we don't need the residual
474
+ hidden_states_fwd = fused_add_norm_fn(
475
+ hidden_states[..., : hidden_states.shape[-1] // 2],
476
+ self.norm_f.weight,
477
+ self.norm_f.bias,
478
+ eps=self.norm_f.eps,
479
+ residual=residual[..., : hidden_states.shape[-1] // 2],
480
+ prenorm=False,
481
+ residual_in_fp32=self.residual_in_fp32,
482
+ )
483
+ hidden_states_rc = fused_add_norm_fn(
484
+ hidden_states[..., hidden_states.shape[-1] // 2 :].flip(
485
+ dims=[-2, -1]
486
+ ),
487
+ self.norm_f.weight,
488
+ self.norm_f.bias,
489
+ eps=self.norm_f.eps,
490
+ residual=residual[..., hidden_states.shape[-1] // 2 :].flip(
491
+ dims=[-2, -1]
492
+ ),
493
+ prenorm=False,
494
+ residual_in_fp32=self.residual_in_fp32,
495
+ )
496
+ hidden_states = torch.cat(
497
+ [hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1
498
+ )
499
+ else:
500
+ # Set prenorm=False here since we don't need the residual
501
+ hidden_states = fused_add_norm_fn(
502
+ hidden_states,
503
+ self.norm_f.weight,
504
+ self.norm_f.bias,
505
+ eps=self.norm_f.eps,
506
+ residual=residual,
507
+ prenorm=False,
508
+ residual_in_fp32=self.residual_in_fp32,
509
+ )
510
+ if output_hidden_states:
511
+ all_hidden_states.append(hidden_states)
512
+ return hidden_states, all_hidden_states
513
+
514
+
515
+ class AxialCaduceusMixerModel(nn.Module):
516
+ def __init__(
517
+ self,
518
+ config: CaduceusConfig,
519
+ device=None,
520
+ dtype=None,
521
+ ) -> None:
522
+ super().__init__()
523
+ factory_kwargs = {"device": device, "dtype": dtype}
524
+
525
+ self.fused_add_norm = config.fused_add_norm
526
+ self.rcps = config.rcps
527
+ self.residual_in_fp32 = config.residual_in_fp32
528
+
529
+ self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
530
+
531
+ self.pos_embeddings = None
532
+ self.add_pos = False
533
+ if config.pos_embeddings == 'Linear':
534
+ self.add_pos = True
535
+ self.pos_embeddings = nn.Linear(in_features=1, out_features=config.d_model, **factory_kwargs)
536
+
537
+ elif config.pos_embeddings == 'Sinusoidal':
538
+ self.pos_embeddings = partial(sinusoidal_encoding, d_model=config.d_model, **factory_kwargs)
539
+
540
+ # Mamba changes the order of residual and layer norm:
541
+ # Instead of LN -> Attn / MLP -> Add, we do:
542
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
543
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
544
+ # This is for performance reasons: we can fuse add + layer_norm.
545
+ if config.fused_add_norm:
546
+ if layer_norm_fn is None or rms_norm_fn is None:
547
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
548
+ row_first = 0  # assume column SSM first
549
+ if config.row_first:  # row SSM first
550
+ row_first = 1
551
+
552
+ self.layers = nn.ModuleList(
553
+ [
554
+ create_axial_block(
555
+ d_model=config.d_model,
556
+ d_intermediate=config.d_intermediate,
557
+ use_mamba2=config.use_mamba2,
558
+ axis=((i + row_first) % 2) + 1, # (i%2) + 1 for columns first
559
+ ssm_cfg=config.ssm_cfg,
560
+ norm_epsilon=config.norm_epsilon,
561
+ rms_norm=config.rms_norm,
562
+ residual_in_fp32=config.residual_in_fp32,
563
+ fused_add_norm=config.fused_add_norm,
564
+ layer_idx=i,
565
+ bidirectional=config.bidirectional,
566
+ bidirectional_strategy=config.bidirectional_strategy,
567
+ bidirectional_weight_tie=config.bidirectional_weight_tie,
568
+ rcps=config.rcps,
569
+ **factory_kwargs,
570
+ )
571
+ for i in range(config.n_layer * 2)
572
+ ]
573
+ )
574
+
575
+ norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
576
+ config.d_model, eps=config.norm_epsilon, **factory_kwargs
577
+ )
578
+ self.norm_f = (
579
+ norm_f
580
+ if (config.fused_add_norm or not config.rcps)
581
+ else RCPSAddNormWrapper(norm_f)
582
+ )
583
+
584
+ def forward(self, input_ids, inputs_embeds=None, input_positions=None, output_hidden_states=False):
585
+ """Mixer forward."""
586
+ all_hidden_states = []
587
+ if inputs_embeds is not None:
588
+ hidden_states = inputs_embeds
589
+ else:
590
+ hidden_states = self.embeddings(input_ids)
591
+ if self.pos_embeddings is not None:
592
+ if self.add_pos:
593
+ pos_embedding = self.pos_embeddings(input_positions[...,None]) #[B, L, D]
594
+ hidden_states = torch.cat([pos_embedding[:,None, ...], hidden_states], dim=1)
595
+ else:
596
+ p_B, p_L = input_positions.size()
597
+ B, R, L, D = hidden_states.size()
598
+ assert p_B == B
599
+ assert p_L == L
600
+ pos_embedding = self.pos_embeddings(positions=input_positions)[:,None, ...] # [B, 1, L, D]
601
+ hidden_states += pos_embedding
602
+
603
+
604
+
605
+ residual = None
606
+ for layer in self.layers:
607
+ if output_hidden_states:
608
+ all_hidden_states.append(hidden_states)
609
+ # TODO: Add support for gradient checkpointing
610
+ hidden_states, residual = layer(
611
+ hidden_states, residual, inference_params=None
612
+ )
613
+
614
+ if not self.fused_add_norm:
615
+ if self.rcps:
616
+ # Set prenorm=False here since we don't need the residual
617
+ hidden_states = self.norm_f(
618
+ hidden_states, residual=residual, prenorm=False
619
+ )
620
+ else:
621
+ residual = (
622
+ (hidden_states + residual)
623
+ if residual is not None
624
+ else hidden_states
625
+ )
626
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
627
+ else:
628
+ fused_add_norm_fn = (
629
+ rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
630
+ )
631
+ if self.rcps:
632
+ # Set prenorm=False here since we don't need the residual
633
+ hidden_states_fwd = fused_add_norm_fn(
634
+ hidden_states[..., : hidden_states.shape[-1] // 2],
635
+ self.norm_f.weight,
636
+ self.norm_f.bias,
637
+ eps=self.norm_f.eps,
638
+ residual=residual[..., : hidden_states.shape[-1] // 2],
639
+ prenorm=False,
640
+ residual_in_fp32=self.residual_in_fp32,
641
+ )
642
+ hidden_states_rc = fused_add_norm_fn(
643
+ hidden_states[..., hidden_states.shape[-1] // 2 :].flip(
644
+ dims=[-2, -1]
645
+ ),
646
+ self.norm_f.weight,
647
+ self.norm_f.bias,
648
+ eps=self.norm_f.eps,
649
+ residual=residual[..., hidden_states.shape[-1] // 2 :].flip(
650
+ dims=[-2, -1]
651
+ ),
652
+ prenorm=False,
653
+ residual_in_fp32=self.residual_in_fp32,
654
+ )
655
+ hidden_states = torch.cat(
656
+ [hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1
657
+ )
658
+ else:
659
+ # Set prenorm=False here since we don't need the residual
660
+ hidden_states = fused_add_norm_fn(
661
+ hidden_states,
662
+ self.norm_f.weight,
663
+ self.norm_f.bias,
664
+ eps=self.norm_f.eps,
665
+ residual=residual,
666
+ prenorm=False,
667
+ residual_in_fp32=self.residual_in_fp32,
668
+ )
669
+ if output_hidden_states:
670
+ all_hidden_states.append(hidden_states)
671
+ if self.pos_embeddings is not None and self.add_pos:
672
+ # remove the positional embeddings from the returned MSA
673
+ hidden_states = hidden_states[:,1:,...]
674
+ return hidden_states, all_hidden_states
675
+
676
+
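The block stack built in AxialCaduceusMixerModel.__init__ above interleaves the two scan directions; the axis assigned to layer i is ((i + row_first) % 2) + 1, where 1 selects the row scan and 2 the column scan in AxialBiMambaWrapper. A tiny check of the pattern:

    # Axis pattern for the first four of the 2 * n_layer blocks.
    for row_first in (0, 1):
        print(row_first, [((i + row_first) % 2) + 1 for i in range(4)])
    # 0 [1, 2, 1, 2]
    # 1 [2, 1, 2, 1]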
677
+ class MixedAxialCaduceusMixerModel(nn.Module):
678
+ """
679
+ A model that alternates between Caduceus (Mamba) and standard attention blocks.
680
+ """
681
+
682
+ def __init__(
683
+ self,
684
+ config: MixedCaduceusConfig,
685
+ device=None,
686
+ dtype=None,
687
+ ) -> None:
688
+ super().__init__()
689
+ factory_kwargs = {"device": device, "dtype": dtype}
690
+
691
+ self.fused_add_norm = config.fused_add_norm
692
+ self.rcps = config.rcps
693
+ self.residual_in_fp32 = config.residual_in_fp32
694
+
695
+ self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
696
+
697
+ # Mamba changes the order of residual and layer norm:
698
+ # Instead of LN -> Attn / MLP -> Add, we do:
699
+ # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
700
+ # the main branch (output of MLP / Mixer). The model definition is unchanged.
701
+ # This is for performance reasons: we can fuse add + layer_norm.
702
+ if config.fused_add_norm:
703
+ if layer_norm_fn is None or rms_norm_fn is None:
704
+ raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
705
+
706
+ layers = []
707
+ for i in range(config.n_layer * 2):
708
+ axis = ((i + 1) % 2) + 1 # 1 for rows, 2 for columns, columns first.
709
+ block = None
710
+ if axis == 1:
711
+ block = create_attention_block(
712
+ d_model=config.attn_d_model,
713
+ n_heads=config.attn_n_heads,
714
+ attention_dropout=config.attn_attn_dropout,
715
+ block_dropout=config.attn_block_dropout,
716
+ layer_idx=i,
717
+ **factory_kwargs,
718
+ )
719
+ elif axis == 2:
720
+ block = create_axial_block(
721
+ d_model=config.d_model,
722
+ d_intermediate=config.d_intermediate,
723
+ use_mamba2=config.use_mamba2,
724
+ axis=axis, # always columns
725
+ ssm_cfg=config.ssm_cfg,
726
+ norm_epsilon=config.norm_epsilon,
727
+ rms_norm=config.rms_norm,
728
+ residual_in_fp32=config.residual_in_fp32,
729
+ fused_add_norm=config.fused_add_norm,
730
+ layer_idx=i,
731
+ bidirectional=config.bidirectional,
732
+ bidirectional_strategy=config.bidirectional_strategy,
733
+ bidirectional_weight_tie=config.bidirectional_weight_tie,
734
+ rcps=config.rcps,
735
+ **factory_kwargs,
736
+ )
737
+ layers.append(block)
738
+
739
+ self.layers = nn.ModuleList(layers)
740
+
741
+ norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
742
+ config.d_model, eps=config.norm_epsilon, **factory_kwargs
743
+ )
744
+ self.norm_f = (
745
+ norm_f
746
+ if (config.fused_add_norm or not config.rcps)
747
+ else RCPSAddNormWrapper(norm_f)
748
+ )
749
+
750
+ def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
751
+ """Mixer forward."""
752
+ all_hidden_states = []
753
+ if inputs_embeds is not None:
754
+ hidden_states = inputs_embeds
755
+ else:
756
+ hidden_states = self.embeddings(input_ids)
757
+
758
+ residual = None
759
+ for layer in self.layers:
760
+ if output_hidden_states:
761
+ all_hidden_states.append(hidden_states)
762
+ # TODO: Add support for gradient checkpointing
763
+ hidden_states, residual = layer(
764
+ hidden_states, residual, inference_params=None
765
+ )
766
+
767
+ if not self.fused_add_norm:
768
+ if self.rcps:
769
+ # Set prenorm=False here since we don't need the residual
770
+ hidden_states = self.norm_f(
771
+ hidden_states, residual=residual, prenorm=False
772
+ )
773
+ else:
774
+ residual = (
775
+ (hidden_states + residual)
776
+ if residual is not None
777
+ else hidden_states
778
+ )
779
+ hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
780
+ else:
781
+ fused_add_norm_fn = (
782
+ rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
783
+ )
784
+ if self.rcps:
785
+ # Set prenorm=False here since we don't need the residual
786
+ hidden_states_fwd = fused_add_norm_fn(
787
+ hidden_states[..., : hidden_states.shape[-1] // 2],
788
+ self.norm_f.weight,
789
+ self.norm_f.bias,
790
+ eps=self.norm_f.eps,
791
+ residual=residual[..., : hidden_states.shape[-1] // 2],
792
+ prenorm=False,
793
+ residual_in_fp32=self.residual_in_fp32,
794
+ )
795
+ hidden_states_rc = fused_add_norm_fn(
796
+ hidden_states[..., hidden_states.shape[-1] // 2 :].flip(
797
+ dims=[-2, -1]
798
+ ),
799
+ self.norm_f.weight,
800
+ self.norm_f.bias,
801
+ eps=self.norm_f.eps,
802
+ residual=residual[..., hidden_states.shape[-1] // 2 :].flip(
803
+ dims=[-2, -1]
804
+ ),
805
+ prenorm=False,
806
+ residual_in_fp32=self.residual_in_fp32,
807
+ )
808
+ hidden_states = torch.cat(
809
+ [hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1
810
+ )
811
+ else:
812
+ # Set prenorm=False here since we don't need the residual
813
+ hidden_states = fused_add_norm_fn(
814
+ hidden_states,
815
+ self.norm_f.weight,
816
+ self.norm_f.bias,
817
+ eps=self.norm_f.eps,
818
+ residual=residual,
819
+ prenorm=False,
820
+ residual_in_fp32=self.residual_in_fp32,
821
+ )
822
+ if output_hidden_states:
823
+ all_hidden_states.append(hidden_states)
824
+ return hidden_states, all_hidden_states
825
+
826
+
827
+ def cross_entropy(logits, y, ignore_index=-100):
828
+ """Cross entropy loss."""
829
+ logits = logits.view(-1, logits.shape[-1])
830
+ y = y.view(-1)
831
+ return F.cross_entropy(logits, y, ignore_index=ignore_index)
832
+
833
+
834
+ def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
835
+ """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
836
+ logits = logits.view(-1, logits.shape[-1])
837
+ y = y.view(-1)
838
+ ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
839
+ loss_weights = loss_weights.view(-1)
840
+ loss_weights[y == ignore_index] = 0.0
841
+ # TODO: Follows GPN implementation, but should we remove weight normalization?
842
+ return (ce * (loss_weights / loss_weights.sum())).sum()
843
+
844
+
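A quick numerical sketch of weighted_cross_entropy above (assuming the function is in scope): ignored labels and zero-weight positions contribute nothing, and the remaining weights are renormalized to sum to one before the weighted sum:

    import torch

    logits = torch.randn(1, 4, 6)               # (batch, length, vocab)
    labels = torch.tensor([[1, 2, -100, 3]])    # -100 is the ignore index
    weights = torch.tensor([[1.0, 2.0, 1.0, 0.0]])
    loss = weighted_cross_entropy(logits, labels, weights, ignore_index=-100)
    print(loss)                                 # scalar tensor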
845
+ class CaduceusPreTrainedModel(PreTrainedModel):
846
+ """PreTrainedModel wrapper for Caduceus backbone."""
847
+
848
+ config_class = CaduceusConfig
849
+ base_model_prefix = "caduceus"
850
+ supports_gradient_checkpointing = False
851
+ _no_split_modules = ["BiMambaWrapper"]
852
+
853
+ def _init_weights(
854
+ self,
855
+ module,
856
+ initializer_range=0.02, # Now only used for embedding layer.
857
+ **kwargs,
858
+ ):
859
+ """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""
860
+
861
+ n_layer = self.config.n_layer
862
+ initialized_cfg = (
863
+ self.config.initializer_cfg
864
+ if self.config.initializer_cfg is not None
865
+ else {}
866
+ )
867
+ rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
868
+ initializer_range = initialized_cfg.get("initializer_range", initializer_range)
869
+ n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)
870
+
871
+ if isinstance(module, nn.Linear):
872
+ if module.bias is not None:
873
+ if not getattr(module.bias, "_no_reinit", False):
874
+ nn.init.zeros_(module.bias)
875
+ elif isinstance(module, nn.Embedding):
876
+ nn.init.normal_(module.weight, std=initializer_range)
877
+
878
+ if rescale_prenorm_residual:
879
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
880
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth.
881
+ # > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
882
+ # residual layers.
883
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
884
+ #
885
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
886
+ for name, p in module.named_parameters():
887
+ if name in ["out_proj.weight", "fc2.weight"]:
888
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
889
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
890
+ # We need to reinit p since this code could be called multiple times
891
+ # Having just p *= scale would repeatedly scale it down
892
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
893
+ with torch.no_grad():
894
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
895
+
896
+ class AxialCaduceusPreTrainedModel(PreTrainedModel):
897
+ """PreTrainedModel wrapper for Caduceus backbone."""
898
+
899
+ config_class = AxialCaduceusConfig
900
+ base_model_prefix = "axial_caduceus"
901
+ supports_gradient_checkpointing = False
902
+ _no_split_modules = ["BiMambaWrapper"]
903
+
904
+ def _init_weights(
905
+ self,
906
+ module,
907
+ initializer_range=0.02, # Now only used for embedding layer.
908
+ **kwargs,
909
+ ):
910
+ """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""
911
+
912
+ n_layer = self.config.n_layer
913
+ initialized_cfg = (
914
+ self.config.initializer_cfg
915
+ if self.config.initializer_cfg is not None
916
+ else {}
917
+ )
918
+ rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
919
+ initializer_range = initialized_cfg.get("initializer_range", initializer_range)
920
+ n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)
921
+
922
+ if isinstance(module, nn.Linear):
923
+ if module.bias is not None:
924
+ if not getattr(module.bias, "_no_reinit", False):
925
+ nn.init.zeros_(module.bias)
926
+ elif isinstance(module, nn.Embedding):
927
+ nn.init.normal_(module.weight, std=initializer_range)
928
+
929
+ if rescale_prenorm_residual:
930
+ # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
931
+ # > A modified initialization which accounts for the accumulation on the residual path with model depth.
932
+ # > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
933
+ # residual layers.
934
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/
935
+ #
936
+ # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
937
+ for name, p in module.named_parameters():
938
+ if name in ["out_proj.weight", "fc2.weight"]:
939
+ # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
940
+ # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
941
+ # We need to reinit p since this code could be called multiple times
942
+ # Having just p *= scale would repeatedly scale it down
943
+ nn.init.kaiming_uniform_(p, a=math.sqrt(5))
944
+ with torch.no_grad():
945
+ p /= math.sqrt(n_residuals_per_layer * n_layer)
946
+
947
+
948
+
949
+ class Caduceus(CaduceusPreTrainedModel):
950
+ """Caduceus model that can be instantiated using HF patterns."""
951
+
952
+ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
953
+ super().__init__(config)
954
+
955
+ if config.rcps:
956
+ assert (
957
+ config.complement_map is not None
958
+ ), "Complement map must be provided for RCPS."
959
+
960
+ # Adjust vocab size and complement maps if vocab padding is set.
961
+ if config.vocab_size % config.pad_vocab_size_multiple != 0:
962
+ config.vocab_size += config.pad_vocab_size_multiple - (
963
+ config.vocab_size % config.pad_vocab_size_multiple
964
+ )
965
+ if config.complement_map is not None and config.vocab_size > len(
966
+ config.complement_map
967
+ ):
968
+ for i in range(len(config.complement_map), config.vocab_size):
969
+ config.complement_map[i] = i
970
+
971
+ self.config = config
972
+ factory_kwargs = {"device": device, "dtype": dtype}
973
+ self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
974
+
975
+ def forward(
976
+ self,
977
+ input_ids: torch.LongTensor = None,
978
+ inputs_embeds: Optional[torch.FloatTensor] = None,
979
+ output_hidden_states: Optional[bool] = None,
980
+ return_dict: Optional[bool] = None,
981
+ ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
982
+ """HF-compatible forward method."""
983
+ output_hidden_states = (
984
+ output_hidden_states
985
+ if output_hidden_states is not None
986
+ else self.config.output_hidden_states
987
+ )
988
+ return_dict = (
989
+ return_dict if return_dict is not None else self.config.use_return_dict
990
+ )
991
+
992
+ hidden_states, all_hidden_states = self.backbone(
993
+ input_ids,
994
+ inputs_embeds=inputs_embeds,
995
+ output_hidden_states=output_hidden_states,
996
+ )
997
+ if return_dict:
998
+ return BaseModelOutputWithNoAttention(
999
+ last_hidden_state=hidden_states,
1000
+ hidden_states=all_hidden_states if output_hidden_states else None,
1001
+ )
1002
+ elif output_hidden_states:
1003
+ return hidden_states, all_hidden_states
1004
+ else:
1005
+ return hidden_states
1006
+
1007
+
1008
+ class AxialCaduceus(AxialCaduceusPreTrainedModel):
1009
+ """Caduceus model that can be instantiated using HF patterns."""
1010
+
1011
+ def __init__(self, config: AxialCaduceusConfig, device=None, dtype=None, **kwargs):
1012
+ super().__init__(config)
1013
+
1014
+ if config.rcps:
1015
+ assert (
1016
+ config.complement_map is not None
1017
+ ), "Complement map must be provided for RCPS."
1018
+
1019
+ # Adjust vocab size and complement maps if vocab padding is set.
1020
+ if config.vocab_size % config.pad_vocab_size_multiple != 0:
1021
+ config.vocab_size += config.pad_vocab_size_multiple - (
1022
+ config.vocab_size % config.pad_vocab_size_multiple
1023
+ )
1024
+ if config.complement_map is not None and config.vocab_size > len(
1025
+ config.complement_map
1026
+ ):
1027
+ for i in range(len(config.complement_map), config.vocab_size):
1028
+ config.complement_map[i] = i
1029
+
1030
+ self.config = config
1031
+ factory_kwargs = {"device": device, "dtype": dtype}
1032
+ self.backbone = AxialCaduceusMixerModel(config, **factory_kwargs, **kwargs)
1033
+
1034
+ def forward(
1035
+ self,
1036
+ input_ids: torch.LongTensor = None,
1037
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1038
+ input_positions: Optional[torch.LongTensor] = None,
1039
+ output_hidden_states: Optional[bool] = None,
1040
+ return_dict: Optional[bool] = None,
1041
+ ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
1042
+ """HF-compatible forward method."""
1043
+ output_hidden_states = (
1044
+ output_hidden_states
1045
+ if output_hidden_states is not None
1046
+ else self.config.output_hidden_states
1047
+ )
1048
+ return_dict = (
1049
+ return_dict if return_dict is not None else self.config.use_return_dict
1050
+ )
1051
+
1052
+ hidden_states, all_hidden_states = self.backbone(
1053
+ input_ids,
1054
+ inputs_embeds=inputs_embeds,
1055
+ input_positions=input_positions,
1056
+ output_hidden_states=output_hidden_states,
1057
+ )
1058
+ if return_dict:
1059
+ return BaseModelOutputWithNoAttention(
1060
+ last_hidden_state=hidden_states,
1061
+ hidden_states=all_hidden_states if output_hidden_states else None,
1062
+ )
1063
+ elif output_hidden_states:
1064
+ return hidden_states, all_hidden_states
1065
+ else:
1066
+ return hidden_states
1067
+
1068
+
1069
+ class MixedAxialCaduceus(CaduceusPreTrainedModel):
1070
+ """Mixed Caduceus/Attention model that can be instantiated using HF patterns."""
1071
+
1072
+ def __init__(self, config: MixedCaduceusConfig, device=None, dtype=None, **kwargs):
1073
+ super().__init__(config)
1074
+
1075
+ if config.rcps:
1076
+ assert (
1077
+ config.complement_map is not None
1078
+ ), "Complement map must be provided for RCPS."
1079
+
1080
+ # Adjust vocab size and complement maps if vocab padding is set.
1081
+ if config.vocab_size % config.pad_vocab_size_multiple != 0:
1082
+ config.vocab_size += config.pad_vocab_size_multiple - (
1083
+ config.vocab_size % config.pad_vocab_size_multiple
1084
+ )
1085
+ if config.complement_map is not None and config.vocab_size > len(
1086
+ config.complement_map
1087
+ ):
1088
+ for i in range(len(config.complement_map), config.vocab_size):
1089
+ config.complement_map[i] = i
1090
+
1091
+ self.config = config
1092
+ factory_kwargs = {"device": device, "dtype": dtype}
1093
+ self.backbone = MixedAxialCaduceusMixerModel(config, **factory_kwargs, **kwargs)
1094
+
1095
+ def forward(
1096
+ self,
1097
+ input_ids: torch.LongTensor = None,
1098
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1099
+ output_hidden_states: Optional[bool] = None,
1100
+ return_dict: Optional[bool] = None,
1101
+ ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
1102
+ """HF-compatible forward method."""
1103
+ output_hidden_states = (
1104
+ output_hidden_states
1105
+ if output_hidden_states is not None
1106
+ else self.config.output_hidden_states
1107
+ )
1108
+ return_dict = (
1109
+ return_dict if return_dict is not None else self.config.use_return_dict
1110
+ )
1111
+
1112
+ hidden_states, all_hidden_states = self.backbone(
1113
+ input_ids,
1114
+ inputs_embeds=inputs_embeds,
1115
+ output_hidden_states=output_hidden_states,
1116
+ )
1117
+ if return_dict:
1118
+ return BaseModelOutputWithNoAttention(
1119
+ last_hidden_state=hidden_states,
1120
+ hidden_states=all_hidden_states if output_hidden_states else None,
1121
+ )
1122
+ elif output_hidden_states:
1123
+ return hidden_states, all_hidden_states
1124
+ else:
1125
+ return hidden_states
1126
+
1127
+
1128
+ class CaduceusForMaskedLM(CaduceusPreTrainedModel):
1129
+ """HF-compatible Caduceus model for masked language modeling."""
1130
+
1131
+ def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
1132
+ super().__init__(config, **kwargs)
1133
+ factory_kwargs = {"device": device, "dtype": dtype}
1134
+ self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
1135
+ if config.rcps:
1136
+ self.lm_head = RCPSLMHead(
1137
+ complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
1138
+ vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
1139
+ true_dim=config.d_model,
1140
+ dtype=dtype,
1141
+ )
1142
+ else:
1143
+ self.lm_head = nn.Linear(
1144
+ config.d_model,
1145
+ self.config.vocab_size, # Use caduceus config as it might have been updated
1146
+ bias=False,
1147
+ **factory_kwargs,
1148
+ )
1149
+
1150
+ # Initialize weights and apply final processing
1151
+ self.post_init()
1152
+
1153
+ def get_input_embeddings(self):
1154
+ return self.caduceus.backbone.embeddings.word_embeddings
1155
+
1156
+ def set_input_embeddings(self, value):
1157
+ if self.config.rcps:
1158
+ raise NotImplementedError(
1159
+ "Setting input embeddings for RCPS LM is not supported."
1160
+ )
1161
+ self.caduceus.backbone.embeddings.word_embeddings = value
1162
+
1163
+ def get_output_embeddings(self):
1164
+ return self.lm_head
1165
+
1166
+ def set_output_embeddings(self, new_embeddings):
1167
+ """Overrides output embeddings."""
1168
+ if self.config.rcps:
1169
+ raise NotImplementedError(
1170
+ "Setting output embeddings for RCPS LM is not supported."
1171
+ )
1172
+ self.lm_head = new_embeddings
1173
+
1174
+ def tie_weights(self):
1175
+ """Tie weights, accounting for RCPS."""
1176
+ if self.config.rcps:
1177
+ self.lm_head.set_weight(self.get_input_embeddings().weight)
1178
+ else:
1179
+ super().tie_weights()
1180
+
1181
+ def get_decoder(self):
1182
+ """Get decoder (backbone) for the model."""
1183
+ return self.caduceus
1184
+
1185
+ def set_decoder(self, decoder):
1186
+ """Set decoder (backbone) for the model."""
1187
+ self.caduceus = decoder
1188
+
1189
+ def forward(
1190
+ self,
1191
+ input_ids: torch.LongTensor = None,
1192
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1193
+ labels: Optional[torch.LongTensor] = None,
1194
+ loss_weights: Optional[torch.FloatTensor] = None,
1195
+ output_hidden_states: Optional[bool] = None,
1196
+ return_dict: Optional[bool] = None,
1197
+ ) -> Union[Tuple, MaskedLMOutput]:
1198
+ """HF-compatible forward method."""
1199
+
1200
+ output_hidden_states = (
1201
+ output_hidden_states
1202
+ if output_hidden_states is not None
1203
+ else self.config.output_hidden_states
1204
+ )
1205
+ return_dict = (
1206
+ return_dict if return_dict is not None else self.config.use_return_dict
1207
+ )
1208
+
1209
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1210
+ outputs = self.caduceus(
1211
+ input_ids=input_ids,
1212
+ inputs_embeds=inputs_embeds,
1213
+ output_hidden_states=output_hidden_states,
1214
+ return_dict=return_dict,
1215
+ )
1216
+
1217
+ hidden_states = outputs[0]
1218
+ logits = self.lm_head(hidden_states)
1219
+ logits = logits.float()
1220
+
1221
+ loss = None
1222
+ if labels is not None:
1223
+ if loss_weights is not None:
1224
+ loss = weighted_cross_entropy(
1225
+ logits, labels, loss_weights, ignore_index=self.config.pad_token_id
1226
+ )
1227
+ else:
1228
+ loss = cross_entropy(
1229
+ logits, labels, ignore_index=self.config.pad_token_id
1230
+ )
1231
+
1232
+ if not return_dict:
1233
+ output = (logits,) + outputs[1:]
1234
+ return (loss,) + output if loss is not None else output
1235
+
1236
+ return MaskedLMOutput(
1237
+ loss=loss,
1238
+ logits=logits,
1239
+ hidden_states=outputs.hidden_states,
1240
+ )
1241
+
1242
+
1243
+ class AxialCaduceusForMaskedLM(AxialCaduceusPreTrainedModel):
1244
+ """HF-compatible Caduceus model for masked language modeling."""
1245
+
1246
+ def __init__(self, config: AxialCaduceusConfig, device=None, dtype=None, **kwargs):
1247
+ super().__init__(config, **kwargs)
1248
+ factory_kwargs = {"device": device, "dtype": dtype}
1249
+ self.caduceus = AxialCaduceus(config, **factory_kwargs, **kwargs)
1250
+ if config.rcps:
1251
+ self.lm_head = RCPSLMHead(
1252
+ complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
1253
+ vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
1254
+ true_dim=config.d_model,
1255
+ dtype=dtype,
1256
+ )
1257
+ else:
1258
+ self.lm_head = nn.Linear(
1259
+ config.d_model,
1260
+ self.config.vocab_size, # Use caduceus config as it might have been updated
1261
+ bias=False,
1262
+ **factory_kwargs,
1263
+ )
1264
+
1265
+ # Initialize weights and apply final processing
1266
+ self.post_init()
1267
+
1268
+ def get_input_embeddings(self):
1269
+ return self.caduceus.backbone.embeddings.word_embeddings
1270
+
1271
+ def set_input_embeddings(self, value):
1272
+ if self.config.rcps:
1273
+ raise NotImplementedError(
1274
+ "Setting input embeddings for RCPS LM is not supported."
1275
+ )
1276
+ self.caduceus.backbone.embeddings.word_embeddings = value
1277
+
1278
+ def get_output_embeddings(self):
1279
+ return self.lm_head
1280
+
1281
+ def set_output_embeddings(self, new_embeddings):
1282
+ """Overrides output embeddings."""
1283
+ if self.config.rcps:
1284
+ raise NotImplementedError(
1285
+ "Setting output embeddings for RCPS LM is not supported."
1286
+ )
1287
+ self.lm_head = new_embeddings
1288
+
1289
+ def tie_weights(self):
1290
+ """Tie weights, accounting for RCPS."""
1291
+ if self.config.rcps:
1292
+ self.lm_head.set_weight(self.get_input_embeddings().weight)
1293
+ else:
1294
+ super().tie_weights()
1295
+
1296
+ def get_decoder(self):
1297
+ """Get decoder (backbone) for the model."""
1298
+ return self.caduceus
1299
+
1300
+ def set_decoder(self, decoder):
1301
+ """Set decoder (backbone) for the model."""
1302
+ self.caduceus = decoder
1303
+
1304
+ def forward(
1305
+ self,
1306
+ input_ids: torch.LongTensor = None,
1307
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1308
+ input_positions: Optional[torch.LongTensor] = None,
1309
+ labels: Optional[torch.LongTensor] = None,
1310
+ loss_weights: Optional[torch.FloatTensor] = None,
1311
+ output_hidden_states: Optional[bool] = None,
1312
+ return_dict: Optional[bool] = None,
1313
+ ) -> Union[Tuple, MaskedLMOutput]:
1314
+ """HF-compatible forward method."""
1315
+
1316
+ output_hidden_states = (
1317
+ output_hidden_states
1318
+ if output_hidden_states is not None
1319
+ else self.config.output_hidden_states
1320
+ )
1321
+ return_dict = (
1322
+ return_dict if return_dict is not None else self.config.use_return_dict
1323
+ )
1324
+
1325
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1326
+ outputs = self.caduceus(
1327
+ input_ids=input_ids,
1328
+ inputs_embeds=inputs_embeds,
1329
+ input_positions=input_positions,
1330
+ output_hidden_states=output_hidden_states,
1331
+ return_dict=return_dict,
1332
+ )
1333
+
1334
+ hidden_states = outputs[0]
1335
+ logits = self.lm_head(hidden_states)
1336
+ logits = logits.float()
1337
+
1338
+ loss = None
1339
+ if labels is not None:
1340
+ if loss_weights is not None:
1341
+ loss = weighted_cross_entropy(
1342
+ logits, labels, loss_weights, ignore_index=self.config.pad_token_id
1343
+ )
1344
+ else:
1345
+ loss = cross_entropy(
1346
+ logits, labels, ignore_index=self.config.pad_token_id
1347
+ )
1348
+
1349
+ if not return_dict:
1350
+ output = (logits,) + outputs[1:]
1351
+ return (loss,) + output if loss is not None else output
1352
+
1353
+ return MaskedLMOutput(
1354
+ loss=loss,
1355
+ logits=logits,
1356
+ hidden_states=outputs.hidden_states,
1357
+ )
1358
+
1359
+
1360
+ class MixedAxialCaduceusForMaskedLM(CaduceusPreTrainedModel):
1361
+ """HF-compatible Caduceus model for masked language modeling."""
1362
+
1363
+ def __init__(self, config: MixedCaduceusConfig, device=None, dtype=None, **kwargs):
1364
+ super().__init__(config, **kwargs)
1365
+ factory_kwargs = {"device": device, "dtype": dtype}
1366
+ self.caduceus = MixedAxialCaduceus(config, **factory_kwargs, **kwargs)
1367
+ if config.rcps:
1368
+ self.lm_head = RCPSLMHead(
1369
+ complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
1370
+ vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
1371
+ true_dim=config.d_model,
1372
+ dtype=dtype,
1373
+ )
1374
+ else:
1375
+ self.lm_head = nn.Linear(
1376
+ config.d_model,
1377
+ self.config.vocab_size, # Use caduceus config as it might have been updated
1378
+ bias=False,
1379
+ **factory_kwargs,
1380
+ )
1381
+
1382
+ # Initialize weights and apply final processing
1383
+ self.post_init()
1384
+
1385
+ def get_input_embeddings(self):
1386
+ return self.caduceus.backbone.embeddings.word_embeddings
1387
+
1388
+ def set_input_embeddings(self, value):
1389
+ if self.config.rcps:
1390
+ raise NotImplementedError(
1391
+ "Setting input embeddings for RCPS LM is not supported."
1392
+ )
1393
+ self.caduceus.backbone.embeddings.word_embeddings = value
1394
+
1395
+ def get_output_embeddings(self):
1396
+ return self.lm_head
1397
+
1398
+ def set_output_embeddings(self, new_embeddings):
1399
+ """Overrides output embeddings."""
1400
+ if self.config.rcps:
1401
+ raise NotImplementedError(
1402
+ "Setting output embeddings for RCPS LM is not supported."
1403
+ )
1404
+ self.lm_head = new_embeddings
1405
+
1406
+ def tie_weights(self):
1407
+ """Tie weights, accounting for RCPS."""
1408
+ if self.config.rcps:
1409
+ self.lm_head.set_weight(self.get_input_embeddings().weight)
1410
+ else:
1411
+ super().tie_weights()
1412
+
1413
+ def get_decoder(self):
1414
+ """Get decoder (backbone) for the model."""
1415
+ return self.caduceus
1416
+
1417
+ def set_decoder(self, decoder):
1418
+ """Set decoder (backbone) for the model."""
1419
+ self.caduceus = decoder
1420
+
1421
+ def forward(
1422
+ self,
1423
+ input_ids: torch.LongTensor = None,
1424
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1425
+ labels: Optional[torch.LongTensor] = None,
1426
+ loss_weights: Optional[torch.FloatTensor] = None,
1427
+ output_hidden_states: Optional[bool] = None,
1428
+ return_dict: Optional[bool] = None,
1429
+ ) -> Union[Tuple, MaskedLMOutput]:
1430
+ """HF-compatible forward method."""
1431
+
1432
+ output_hidden_states = (
1433
+ output_hidden_states
1434
+ if output_hidden_states is not None
1435
+ else self.config.output_hidden_states
1436
+ )
1437
+ return_dict = (
1438
+ return_dict if return_dict is not None else self.config.use_return_dict
1439
+ )
1440
+
1441
+ # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
1442
+ outputs = self.caduceus(
1443
+ input_ids=input_ids,
1444
+ inputs_embeds=inputs_embeds,
1445
+ output_hidden_states=output_hidden_states,
1446
+ return_dict=return_dict,
1447
+ )
1448
+
1449
+ hidden_states = outputs[0]
1450
+ logits = self.lm_head(hidden_states)
1451
+ logits = logits.float()
1452
+
1453
+ loss = None
1454
+ if labels is not None:
1455
+ if loss_weights is not None:
1456
+ loss = weighted_cross_entropy(
1457
+ logits, labels, loss_weights, ignore_index=self.config.pad_token_id
1458
+ )
1459
+ else:
1460
+ loss = cross_entropy(
1461
+ logits, labels, ignore_index=self.config.pad_token_id
1462
+ )
1463
+
1464
+ if not return_dict:
1465
+ output = (logits,) + outputs[1:]
1466
+ return (loss,) + output if loss is not None else output
1467
+
1468
+ return MaskedLMOutput(
1469
+ loss=loss,
1470
+ logits=logits,
1471
+ hidden_states=outputs.hidden_states,
1472
+ )
1473
+
1474
+
1475
+ class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
1476
+ def __init__(
1477
+ self,
1478
+ config: CaduceusConfig,
1479
+ pooling_strategy: str = "mean",
1480
+ conjoin_train: bool = False,
1481
+ conjoin_eval: bool = False,
1482
+ device=None,
1483
+ dtype=None,
1484
+ **kwargs,
1485
+ ):
1486
+ super().__init__(config, **kwargs)
1487
+ if pooling_strategy not in ["mean", "max", "first", "last"]:
1488
+ raise NotImplementedError(
1489
+ f"Pooling strategy `{pooling_strategy}` not implemented."
1490
+ )
1491
+ self.pooling_strategy = pooling_strategy
1492
+ factory_kwargs = {"device": device, "dtype": dtype}
1493
+ self.num_labels = kwargs.get("num_labels", config.num_labels)
1494
+ self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
1495
+ self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
1496
+
1497
+ self.conjoin_train = conjoin_train
1498
+ self.conjoin_eval = conjoin_eval
1499
+
1500
+ # Initialize weights and apply final processing
1501
+ self.post_init()
1502
+
1503
+ def get_input_embeddings(self):
1504
+ return self.caduceus.backbone.embeddings.word_embeddings
1505
+
1506
+ def set_input_embeddings(self, value):
1507
+ if self.config.rcps:
1508
+ raise NotImplementedError(
1509
+ "Setting input embeddings for RCPS LM is not supported."
1510
+ )
1511
+ self.caduceus.backbone.embeddings.word_embeddings = value
1512
+
1513
+ def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
1514
+ """Pools hidden states along sequence length dimension."""
1515
+ if (
1516
+ self.pooling_strategy == "mean"
1517
+ ): # Mean pooling along sequence length dimension
1518
+ return hidden_states.mean(dim=sequence_length_dim)
1519
+ if (
1520
+ self.pooling_strategy == "max"
1521
+ ): # Max pooling along sequence length dimension
1522
+ return hidden_states.max(dim=sequence_length_dim).values
1523
+ if (
1524
+ self.pooling_strategy == "last"
1525
+ ): # Use embedding of last token in the sequence
1526
+ return hidden_states.moveaxis(sequence_length_dim, 0)[
1527
+ -1, ...
1528
+ ]
1529
+ if (
1530
+ self.pooling_strategy == "first"
1531
+ ): # Use embedding of first token in the sequence
1532
+ return hidden_states.moveaxis(sequence_length_dim, 0)[0, ...]
1533
+
1534
+ def forward(
1535
+ self,
1536
+ input_ids: torch.LongTensor = None,
1537
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1538
+ labels: Optional[torch.LongTensor] = None,
1539
+ output_hidden_states: Optional[bool] = None,
1540
+ return_dict: Optional[bool] = None,
1541
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1542
+ r"""
1543
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1544
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1545
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1546
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1547
+ """
1548
+ return_dict = (
1549
+ return_dict if return_dict is not None else self.config.use_return_dict
1550
+ )
1551
+
1552
+ # Get hidden representations from the backbone
1553
+ if self.config.rcps: # Hidden states have 2 * d_model channels for RCPS
1554
+ transformer_outputs = self.caduceus(
1555
+ input_ids,
1556
+ inputs_embeds=inputs_embeds,
1557
+ output_hidden_states=output_hidden_states,
1558
+ return_dict=return_dict,
1559
+ )
1560
+ hidden_states = torch.stack(
1561
+ [
1562
+ transformer_outputs[0][..., : self.config.d_model],
1563
+ torch.flip(
1564
+ transformer_outputs[0][..., self.config.d_model :], dims=[1, 2]
1565
+ ),
1566
+ ],
1567
+ dim=-1,
1568
+ )
1569
+ elif self.conjoin_train or (
1570
+ self.conjoin_eval and not self.training
1571
+ ): # For conjoining / post-hoc conjoining
1572
+ assert input_ids is not None, "`input_ids` must be provided for conjoining."
1573
+ assert (
1574
+ input_ids.ndim == 3
1575
+ ), "`input_ids` must be 3D tensor: channels corresponds to forward and rc strands."
1576
+ transformer_outputs = self.caduceus(
1577
+ input_ids[..., 0],
1578
+ inputs_embeds=None,
1579
+ output_hidden_states=output_hidden_states,
1580
+ return_dict=return_dict,
1581
+ )
1582
+ transformer_outputs_rc = self.caduceus(
1583
+ input_ids[..., 1],
1584
+ inputs_embeds=None,
1585
+ output_hidden_states=output_hidden_states,
1586
+ return_dict=return_dict,
1587
+ )
1588
+ # Stack along channel dimension (dim=-1)
1589
+ hidden_states = torch.stack(
1590
+ [transformer_outputs[0], transformer_outputs_rc[0]], dim=-1
1591
+ )
1592
+ else:
1593
+ transformer_outputs = self.caduceus(
1594
+ input_ids,
1595
+ inputs_embeds=None,
1596
+ output_hidden_states=output_hidden_states,
1597
+ return_dict=return_dict,
1598
+ )
1599
+ hidden_states = transformer_outputs[0]
1600
+
1601
+ # Pool and get logits
1602
+ pooled_hidden_states = self.pool_hidden_states(hidden_states)
1603
+ # Potentially run `score` twice (with parameters shared) for conjoining
1604
+ if (
1605
+ hidden_states.ndim == 4
1606
+ ): # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
1607
+ logits_fwd = self.score(pooled_hidden_states[..., 0])
1608
+ logits_rc = self.score(pooled_hidden_states[..., 1])
1609
+ logits = (logits_fwd + logits_rc) / 2
1610
+ else:
1611
+ logits = self.score(pooled_hidden_states)
1612
+
1613
+ loss = None
1614
+ if labels is not None:
1615
+ labels = labels.to(logits.device)
1616
+ if self.config.problem_type is None:
1617
+ if self.num_labels == 1:
1618
+ self.config.problem_type = "regression"
1619
+ elif self.num_labels > 1 and (
1620
+ labels.dtype == torch.long or labels.dtype == torch.int
1621
+ ):
1622
+ self.config.problem_type = "single_label_classification"
1623
+ else:
1624
+ self.config.problem_type = "multi_label_classification"
1625
+
1626
+ if self.config.problem_type == "regression":
1627
+ if self.num_labels == 1:
1628
+ loss = F.mse_loss(logits.squeeze(), labels.squeeze())
1629
+ else:
1630
+ loss = F.mse_loss(logits, labels)
1631
+ elif self.config.problem_type == "single_label_classification":
1632
+ loss = F.cross_entropy(
1633
+ logits.view(-1, self.num_labels), labels.view(-1)
1634
+ )
1635
+ elif self.config.problem_type == "multi_label_classification":
1636
+ loss = F.binary_cross_entropy_with_logits(logits, labels)
1637
+ if not return_dict:
1638
+ output = (logits,) + transformer_outputs[1:]
1639
+ return ((loss,) + output) if loss is not None else output
1640
+
1641
+ return SequenceClassifierOutput(
1642
+ loss=loss,
1643
+ logits=logits,
1644
+ hidden_states=transformer_outputs.hidden_states,
1645
+ )
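A hedged usage sketch (an editorial addition, not part of the upload): assuming the checkpoint's config exposes these classes through the Hugging Face auto classes, the masked-LM model can typically be loaded with `trust_remote_code=True`. The repo id below is a placeholder, not the actual Hub path.

from transformers import AutoConfig, AutoModelForMaskedLM

repo_id = "org/axial-caduceus"  # placeholder; substitute the actual Hub path of this upload
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
# The forward call then follows the HF masked-LM signature defined above, e.g.
# model(input_ids=..., labels=...) returns a MaskedLMOutput with loss, logits, and hidden_states.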
modeling_rcps.py ADDED
@@ -0,0 +1,243 @@
1
+ """Reverse-complement equivariant modules.
2
+
3
+ """
4
+ from collections import OrderedDict
5
+ from typing import Optional
6
+
7
+ import torch
8
+ from torch import Tensor
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ try:
13
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
14
+ except ImportError:
15
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
16
+
17
+
18
+ class RCPSEmbedding(nn.Module):
19
+ """Embedding layer that supports reverse-complement equivariance."""
20
+ def __init__(self, vocab_size: int, d_model: int, complement_map: dict, **factory_kwargs):
21
+ """
22
+ Args:
23
+ vocab_size: Size of vocabulary.
24
+ d_model: Dimensionality of embedding (actual embedding matrix will have 1/2 the output dim).
25
+ complement_map: Dictionary mapping each token id to its complement.
26
+ """
27
+ super().__init__()
28
+ self.register_buffer(
29
+ "complement_map",
30
+ torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
31
+ )
32
+ self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
33
+
34
+ @property
35
+ def weight(self):
36
+ """Embedding weights."""
37
+ return self.embedding.weight
38
+
39
+ def set_weight(self, value):
40
+ """Set embedding weights."""
41
+ self.embedding.weight = value
42
+
43
+ def rc(self, x):
44
+ """Reverse-complement a tensor of input_ids by flipping along length dimension and complementing the ids."""
45
+ return torch.gather(
46
+ self.complement_map.unsqueeze(0).expand(x.shape[0], -1),
47
+ dim=1,
48
+ index=torch.flip(x, dims=[-1])
49
+ )
50
+
51
+ def forward(self, input_ids):
52
+ """Reverse-complement equivariant forward pass.
53
+
54
+ This embedding module doubles the output dimensionality to support reverse-complement equivariance.
55
+
56
+ Args:
57
+ input_ids: Input tensor of shape (batch_size, seq_len)
58
+ Returns:
59
+ Embedding tensor of shape (batch_size, seq_len, d_model * 2)
60
+ """
61
+ fwd_out = self.embedding(input_ids)
62
+ rc_out = torch.flip(self.embedding(self.rc(input_ids)), dims=[-2, -1])
63
+
64
+ return torch.cat([fwd_out, rc_out], dim=-1)
65
+
66
+
67
+ class RCPSWrapper(nn.Module):
68
+ """Wrapper to convert arbitrary nn.Module into a reverse-complement equivariant module.
69
+
70
+ See ref. "Towards a Better Understanding of Reverse-Complement Equivariance for Deep Learning Models in Regulatory
71
+ Genomics", Zhou et al. (2022), https://proceedings.mlr.press/v165/zhou22a.html for more details.
72
+ """
73
+ def __init__(self, submodule: nn.Module):
74
+ super().__init__()
75
+ self.submodule = submodule
76
+
77
+ @staticmethod
78
+ def rc(x):
79
+ """Reverse-complement a tensor by flipping the length (dim=-2) and channel (dim=-1) dimensions."""
80
+ return torch.flip(x, dims=[-2, -1])
81
+
82
+ def forward(self, x, **kwargs):
83
+ """Reverse-complement equivariant forward pass.
84
+
85
+ Args:
86
+ x: Input tensor of shape (batch_size, seq_len, channels)
87
+ Returns:
88
+ Output tensor of shape (batch_size, seq_len, channels)
89
+ """
90
+ n_channels = x.shape[-1]
91
+ # Run submodule along sequence
92
+ fwd_out = self.submodule(x[..., :n_channels // 2], **kwargs)
93
+ # Run submodule along rc-sequence
94
+ rc_out = self.submodule(self.rc(x[..., n_channels // 2:]), **kwargs)
95
+ # Concatenate along channel dimension (dim=-1)
96
+ return torch.cat([fwd_out, self.rc(rc_out)], dim=-1)
97
+
98
+
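# Editorial aside (not part of the upload): a minimal check of the property this wrapper provides.
# For any submodule mapping (batch, seq_len, d) -> (batch, seq_len, d), the wrapped module is
# equivariant to the reverse-complement operation on hidden states, i.e. a flip along the length
# and channel dimensions. The toy nn.Linear submodule is an illustrative assumption only.
import torch
from torch import nn

d = 8
wrapper = RCPSWrapper(nn.Linear(d, d))  # import RCPSWrapper from modeling_rcps if run standalone
x = torch.randn(2, 16, 2 * d)                # (batch, seq_len, 2 * d)
rc = lambda t: torch.flip(t, dims=[-2, -1])  # reverse-complement of a hidden-state tensor
assert torch.allclose(wrapper(rc(x)), rc(wrapper(x)))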
99
+ class RCPSAddNormWrapper(RCPSWrapper):
100
+ """RC equivariant AddNorm layer."""
101
+ def __init__(self, submodule: nn.Module):
102
+ super().__init__(submodule)
103
+
104
+ def forward(self, x, residual=None, prenorm=False):
105
+ """
106
+ Args:
107
+ x: Input tensor of shape (batch_size, seq_len, channels)
108
+ residual: Residual tensor of shape (batch_size, seq_len, channels) or None.
109
+ prenorm: Whether to return residual.
110
+ """
111
+ n_channels = x.shape[-1]
112
+ if residual is None:
113
+ residual = x
114
+ x_fwd = self.submodule(x[..., :n_channels // 2].to(dtype=self.submodule.weight.dtype))
115
+ x_rc = self.submodule(self.rc(x[..., n_channels // 2:]).to(dtype=self.submodule.weight.dtype))
116
+ x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
117
+ else:
118
+ residual_fwd = x[..., :n_channels // 2] + residual[..., :n_channels // 2]
119
+ x_fwd = self.submodule(residual_fwd.to(dtype=self.submodule.weight.dtype))
120
+
121
+ residual_rc = self.rc(x[..., n_channels // 2:]) + self.rc(residual[..., n_channels // 2:])
122
+ x_rc = self.submodule(residual_rc.to(dtype=self.submodule.weight.dtype))
123
+
124
+ residual = torch.cat([residual_fwd, self.rc(residual_rc)], dim=-1)
125
+ x = torch.cat([x_fwd, self.rc(x_rc)], dim=-1)
126
+
127
+ return x if not prenorm else (x, residual)
128
+
129
+
130
+ class RCPSMambaBlock(nn.Module):
131
+ def __init__(
132
+ self,
133
+ dim,
134
+ mixer_cls,
135
+ norm_cls=nn.LayerNorm,
136
+ fused_add_norm=False,
137
+ residual_in_fp32=False,
138
+ device=None, # Keep for consistency with original Mamba Block
139
+ dtype=None, # Keep for consistency with original Mamba Block
140
+ ):
141
+ """RCPS version of simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.
142
+
143
+ Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py
144
+ """
145
+ super().__init__()
146
+ self.residual_in_fp32 = residual_in_fp32
147
+ self.fused_add_norm = fused_add_norm
148
+ self.mixer = RCPSWrapper(mixer_cls(dim))
149
+ norm_f = norm_cls(dim)
150
+ self.norm = norm_f if fused_add_norm else RCPSAddNormWrapper(norm_f)
151
+ if self.fused_add_norm:
152
+ assert RMSNorm is not None, "RMSNorm import fails"
153
+ assert isinstance(
154
+ self.norm, (nn.LayerNorm, RMSNorm)
155
+ ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
156
+
157
+ def forward(
158
+ self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
159
+ ):
160
+ r"""Pass the input through the encoder layer.
161
+
162
+ Args:
163
+ hidden_states: the sequence to the encoder layer (required).
164
+ residual: hidden_states = Mixer(LN(residual)).
165
+ inference_params: inference parameters for mixer.
166
+ """
167
+ if not self.fused_add_norm:
168
+ hidden_states, residual = self.norm(hidden_states, residual=residual, prenorm=True)
169
+ if self.residual_in_fp32:
170
+ residual = residual.to(torch.float32)
171
+ else:
172
+ fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
173
+
174
+ hidden_states_fwd, residual_fwd = fused_add_norm_fn(
175
+ hidden_states[..., hidden_states.shape[-1] // 2:],
176
+ self.norm.weight,
177
+ self.norm.bias,
178
+ residual=residual[..., hidden_states.shape[-1] // 2:] if residual is not None else None,
179
+ prenorm=True,
180
+ residual_in_fp32=self.residual_in_fp32,
181
+ eps=self.norm.eps,
182
+ )
183
+
184
+ hidden_states_rc, residual_rc = fused_add_norm_fn(
185
+ hidden_states[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]),
186
+ self.norm.weight,
187
+ self.norm.bias,
188
+ residual=residual[..., :hidden_states.shape[-1] // 2].flip(dims=[-2, -1]) if residual is not None else None,
189
+ prenorm=True,
190
+ residual_in_fp32=self.residual_in_fp32,
191
+ eps=self.norm.eps,
192
+ )
193
+ hidden_states = torch.cat([hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1)
194
+ residual = torch.cat([residual_fwd, residual_rc.flip(dims=[-2, -1])], dim=-1)
195
+ hidden_states = self.mixer(hidden_states, inference_params=inference_params)
196
+ return hidden_states, residual
197
+
198
+ def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
199
+ """Allocate inference cache for mixer.
200
+
201
+ Keep for compatibility with original Mamba Block.
202
+ """
203
+ return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
204
+
205
+
206
+ class RCPSLMHead(nn.Module):
207
+ """LM Head for reverse-complement equivariant inputs, which have dim * 2 relative to standard inputs."""
208
+ def __init__(self, true_dim: int, vocab_size: int, complement_map: dict, **factory_kwargs):
209
+ """
210
+ `true_dim` corresponds to the actual dimensionality of the input were it not reverse-complement
211
+ equivariant, i.e. 0.5 times the actual input dim.
212
+ """
213
+ super().__init__()
214
+ self.register_buffer(
215
+ "complement_map",
216
+ torch.tensor(list(OrderedDict(complement_map).values()), dtype=torch.long)
217
+ )
218
+ self.true_dim = true_dim
219
+ self.lm_head = nn.Linear(true_dim, vocab_size, bias=False, **factory_kwargs)
220
+
221
+ @property
222
+ def weight(self):
223
+ """LM head weights."""
224
+ return self.lm_head.weight
225
+
226
+ def set_weight(self, value):
227
+ """Set LM head weights."""
228
+ self.lm_head.weight = value
229
+
230
+ def forward(self, x):
231
+ """
232
+ Args:
233
+ x: Input tensor of shape (batch_size, seq_len, dim), where dim = 2 * true_dim.
234
+ """
235
+ n_channels = x.shape[-1]
236
+ assert n_channels == 2 * self.true_dim, "Input must have 2 * true_dim channels."
237
+ fwd_logits = F.linear(x[..., :n_channels // 2], self.weight, bias=self.lm_head.bias)
238
+ rc_logits = F.linear(
239
+ torch.flip(x[..., n_channels // 2:], dims=[-1]),
240
+ self.weight[self.complement_map, :],
241
+ bias=self.lm_head.bias
242
+ )
243
+ return fwd_logits + rc_logits
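A final hedged sketch (an editorial addition, not part of the upload): it wires RCPSEmbedding and RCPSLMHead together with a toy 4-token complement map and checks the end-to-end reverse-complement equivariance these modules are built for: the logits of a reverse-complemented sequence equal the length-flipped, complement-permuted logits of the forward sequence. The module path and the toy vocabulary are illustrative assumptions.

import torch
from modeling_rcps import RCPSEmbedding, RCPSLMHead  # assumes this file is importable

torch.manual_seed(0)
comp = {0: 1, 1: 0, 2: 3, 3: 2}  # toy complement map, e.g. A<->T, C<->G
emb = RCPSEmbedding(vocab_size=4, d_model=8, complement_map=comp)
head = RCPSLMHead(true_dim=8, vocab_size=4, complement_map=comp)

ids = torch.randint(0, 4, (2, 16))  # (batch, seq_len)
rc_ids = emb.rc(ids)                # reverse-complemented token ids

logits_fwd = head(emb(ids))
logits_rc = head(emb(rc_ids))

comp_idx = torch.tensor([comp[i] for i in range(4)])
assert torch.allclose(logits_rc, torch.flip(logits_fwd, dims=[1])[..., comp_idx], atol=1e-5)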