|
"""Caduceus model for Hugging Face. |
|
|
|
""" |
|
|
|
import math |
|
from functools import partial |
|
from typing import Optional, Tuple, Union |
|
|
|
import torch |
|
|
|
|
|
from mamba_ssm import Mamba, Mamba2 |
|
from mamba_ssm.modules.block import Block |
|
from mamba_ssm.modules.mlp import GatedMLP |
|
from torch import nn |
|
from torch.nn import functional as F |
|
from torch.nn.parallel import parallel_apply |
|
from transformers import PreTrainedModel |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithNoAttention, |
|
MaskedLMOutput, |
|
SequenceClassifierOutput, |
|
) |
|
|
|
try: |
|
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn |
|
except ImportError: |
|
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None |
|
|
|
from .configuration_caduceus import CaduceusConfig, MixedCaduceusConfig, AxialCaduceusConfig |
|
from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock |
|
|
|
|
|
|
|
|
|
def sinusoidal_encoding(positions: torch.Tensor, d_model: int, device=None, dtype=None): |
|
""" |
|
from https://github.com/wzlxjtu/PositionalEncoding2D |
|
    :param positions: tensor of input positions, shape [B, L]
    :param d_model: embedding dimension of the model (must be even)
    :return: positional encoding tensor of shape [B, L, d_model]
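
    Example (illustrative shapes only):
        pos = torch.arange(8)[None, :]             # (1, 8)
        pe = sinusoidal_encoding(pos, d_model=16)  # (1, 8, 16)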
|
""" |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
if d_model % 2 != 0: |
|
raise ValueError("Cannot use sin/cos positional encoding with " |
|
"odd dim (got dim={:d})".format(d_model)) |
|
B, L = positions.size() |
|
pe = torch.zeros(B, L, d_model, **factory_kwargs) |
|
|
|
|
|
position = positions.unsqueeze(-1) |
|
div_term = torch.exp((torch.arange(0, d_model, 2, device=position.device, dtype=torch.float) * |
|
-(math.log(10000.0) / d_model))) |
|
pe[:, :, 0::2] = torch.sin(position.float() * div_term) |
|
pe[:, :, 1::2] = torch.cos(position.float() * div_term) |
|
pe = pe.to(**factory_kwargs) |
|
return pe |
|
|
|
def create_block( |
|
d_model, |
|
ssm_cfg=None, |
|
norm_epsilon=1e-5, |
|
rms_norm=False, |
|
residual_in_fp32=False, |
|
fused_add_norm=False, |
|
layer_idx=None, |
|
bidirectional=True, |
|
bidirectional_strategy="add", |
|
bidirectional_weight_tie=True, |
|
rcps=False, |
|
device=None, |
|
dtype=None, |
|
): |
|
"""Create Caduceus block. |
|
|
|
Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py |
|
""" |
|
if ssm_cfg is None: |
|
ssm_cfg = {} |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
bidirectional_kwargs = { |
|
"bidirectional": bidirectional, |
|
"bidirectional_strategy": bidirectional_strategy, |
|
"bidirectional_weight_tie": bidirectional_weight_tie, |
|
} |
|
mixer_cls = partial( |
|
BiMambaWrapper, |
|
layer_idx=layer_idx, |
|
**ssm_cfg, |
|
**bidirectional_kwargs, |
|
**factory_kwargs, |
|
) |
|
norm_cls = partial( |
|
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs |
|
) |
|
block_cls = RCPSMambaBlock if rcps else Block |
|
    # Caduceus blocks use no intermediate MLP (d_intermediate is fixed to 0, so mlp_cls is nn.Identity).
    d_intermediate = 0
|
if d_intermediate == 0: |
|
mlp_cls = nn.Identity |
|
else: |
|
mlp_cls = partial( |
|
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs |
|
) |
|
block = block_cls( |
|
dim=d_model, |
|
mixer_cls=mixer_cls, |
|
mlp_cls=mlp_cls, |
|
norm_cls=norm_cls, |
|
fused_add_norm=fused_add_norm, |
|
residual_in_fp32=residual_in_fp32, |
|
) |
|
block.layer_idx = layer_idx |
|
return block |
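

# Illustrative sketch of how a single block returned by create_block behaves (comment only;
# assumes a CUDA device, which the mamba_ssm kernels require):
#
#   blk = create_block(d_model=128, layer_idx=0).to("cuda")
#   x = torch.randn(2, 1024, 128, device="cuda")
#   h, residual = blk(x, None)   # Block returns (hidden_states, residual), both (2, 1024, 128)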
|
|
|
|
|
def create_axial_block( |
|
d_model, |
|
d_intermediate, |
|
use_mamba2, |
|
axis, |
|
ssm_cfg=None, |
|
norm_epsilon=1e-5, |
|
rms_norm=False, |
|
residual_in_fp32=False, |
|
fused_add_norm=False, |
|
layer_idx=None, |
|
bidirectional=True, |
|
bidirectional_strategy="add", |
|
bidirectional_weight_tie=True, |
|
rcps=False, |
|
device=None, |
|
dtype=None, |
|
): |
|
"""Create an axial Caduceus block composed of two AxialCaduceus blocks, one for row and one for columns. |
|
|
|
Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py |
|
""" |
|
if ssm_cfg is None: |
|
ssm_cfg = {} |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
bidirectional_kwargs = { |
|
"bidirectional": bidirectional, |
|
"bidirectional_strategy": bidirectional_strategy, |
|
"bidirectional_weight_tie": bidirectional_weight_tie, |
|
} |
|
mixer_cls = partial( |
|
AxialBiMambaWrapper, |
|
use_mamba2=use_mamba2, |
|
axis=axis, |
|
layer_idx=layer_idx, |
|
**ssm_cfg, |
|
**bidirectional_kwargs, |
|
**factory_kwargs, |
|
) |
|
norm_cls = partial( |
|
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs |
|
) |
|
block_cls = RCPSMambaBlock if rcps else Block |
|
if d_intermediate == 0: |
|
mlp_cls = nn.Identity |
|
else: |
|
mlp_cls = partial( |
|
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs |
|
) |
|
|
|
block = block_cls( |
|
dim=d_model, |
|
mixer_cls=mixer_cls, |
|
mlp_cls=mlp_cls, |
|
norm_cls=norm_cls, |
|
fused_add_norm=fused_add_norm, |
|
residual_in_fp32=residual_in_fp32, |
|
) |
|
block.layer_idx = layer_idx |
|
return block |
|
|
|
def create_attention_block( |
|
d_model: int, |
|
n_heads: int, |
|
attention_dropout: float, |
|
block_dropout: float, |
|
layer_idx=None, |
|
device=None, |
|
dtype=None, |
|
): |
|
"""Create an RowAttention block from MSATransformer.""" |
|
raise NotImplementedError() |
|
|
|
class BiMambaWrapper(nn.Module): |
|
"""Thin wrapper around Mamba to support bi-directionality.""" |
|
|
|
def __init__( |
|
self, |
|
d_model: int, |
|
bidirectional: bool = True, |
|
bidirectional_strategy: Optional[str] = "add", |
|
bidirectional_weight_tie: bool = True, |
|
**mamba_kwargs, |
|
): |
|
super().__init__() |
|
if bidirectional and bidirectional_strategy is None: |
|
bidirectional_strategy = "add" |
|
if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]: |
|
raise NotImplementedError( |
|
f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!" |
|
) |
|
self.bidirectional = bidirectional |
|
self.bidirectional_strategy = bidirectional_strategy |
|
self.mamba_fwd = Mamba(d_model=d_model, **mamba_kwargs) |
|
if bidirectional: |
|
self.mamba_rev = Mamba(d_model=d_model, **mamba_kwargs) |
|
            if bidirectional_weight_tie:
                # Tie input/output projection weights between the forward and reverse Mamba.
|
self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight |
|
self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias |
|
self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight |
|
self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias |
|
else: |
|
self.mamba_rev = None |
|
|
|
def forward(self, hidden_states, inference_params=None): |
|
"""Bidirectional-enabled forward pass |
|
|
|
hidden_states: (B, L, D) |
|
Returns: same shape as hidden_states |
|
""" |
|
out = self.mamba_fwd(hidden_states, inference_params=inference_params) |
|
if self.bidirectional: |
|
out_rev = self.mamba_rev( |
|
hidden_states.flip( |
|
dims=(1,) |
|
), |
|
inference_params=inference_params, |
|
).flip(dims=(1,)) |
|
if self.bidirectional_strategy == "add": |
|
out = out + out_rev |
|
elif self.bidirectional_strategy == "ew_multiply": |
|
out = out * out_rev |
|
else: |
|
raise NotImplementedError( |
|
f"`{self.bidirectional_strategy}` for bi-directionality not implemented!" |
|
) |
|
return out |
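

# Minimal usage sketch for BiMambaWrapper (comment only; assumes a CUDA device, which the
# mamba_ssm kernels require):
#
#   layer = BiMambaWrapper(d_model=128).to("cuda")
#   x = torch.randn(2, 1024, 128, device="cuda")
#   y = layer(x)   # (2, 1024, 128); forward and reversed passes combined via the "add" strategy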
|
|
|
|
|
class AxialBiMambaWrapper(nn.Module): |
|
"""Thin wrapper around BiMamba to support running and aggregating over rows. |
|
axis=1 for RowMamba, axis=2 for column Mamba |
|
""" |
|
|
|
def __init__( |
|
self, |
|
d_model: int, |
|
use_mamba2: bool, |
|
bidirectional: bool = True, |
|
bidirectional_strategy: Optional[str] = "add", |
|
bidirectional_weight_tie: bool = True, |
|
axis: int = 1, |
|
**mamba_kwargs, |
|
): |
|
super().__init__() |
|
if bidirectional and bidirectional_strategy is None: |
|
bidirectional_strategy = "add" |
|
if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]: |
|
raise NotImplementedError( |
|
f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!" |
|
) |
|
self.bidirectional = bidirectional |
|
self.bidirectional_strategy = bidirectional_strategy |
|
self.mamba_fwd = Mamba2(d_model=d_model, **mamba_kwargs) if use_mamba2 else Mamba(d_model=d_model, **mamba_kwargs) |
|
self.axis = axis |
|
if bidirectional: |
|
self.mamba_rev = Mamba2(d_model=d_model, **mamba_kwargs) if use_mamba2 else Mamba(d_model=d_model, **mamba_kwargs) |
|
            if bidirectional_weight_tie:
                # Tie input/output projection weights between the forward and reverse Mamba.
|
self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight |
|
self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias |
|
self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight |
|
self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias |
|
else: |
|
self.mamba_rev = None |
|
|
|
def forward(self, hidden_states, inference_params=None): |
|
"""Bidirectional-enabled forward pass |
|
|
|
hidden_states: (B, R, C, D) |
|
Returns: same shape as hidden_states |
|
""" |
|
def apply_mamba(x): |
|
out = self.mamba_fwd(x, inference_params=inference_params) |
|
if self.bidirectional: |
|
out_rev = self.mamba_rev( |
|
x.flip( |
|
dims=(1,) |
|
), |
|
inference_params=inference_params, |
|
).flip(dims=(1,)) |
|
if self.bidirectional_strategy == "add": |
|
out = out + out_rev |
|
elif self.bidirectional_strategy == "ew_multiply": |
|
out = out * out_rev |
|
else: |
|
raise NotImplementedError( |
|
f"`{self.bidirectional_strategy}` for bi-directionality not implemented!" |
|
) |
|
return out |
|
batch, rows, columns, hidden_dim = hidden_states.size() |
|
if self.axis == 1: |
|
hidden_states = hidden_states.permute(1, 0, 2, 3) |
|
axis_len = rows |
|
        elif self.axis == 2:
            hidden_states = hidden_states.permute(2, 0, 1, 3)
            axis_len = columns
        else:
            raise ValueError(f"Unsupported axis {self.axis}; expected 1 (rows) or 2 (columns).")
|
        # Flatten the chosen axis into the batch dimension so each row/column is an independent sequence.
outs = apply_mamba(hidden_states.reshape(axis_len * batch, -1, hidden_dim)) |
|
out = outs.reshape(axis_len, batch, -1, hidden_dim) |
|
        # Undo the permutation to restore the (batch, rows, columns, hidden) layout.
if self.axis == 1: |
|
out = out.permute(1, 0, 2, 3) |
|
elif self.axis == 2: |
|
out = out.permute(1, 2, 0, 3) |
|
return out |
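

# Minimal usage sketch for AxialBiMambaWrapper (comment only; assumes a CUDA device):
#
#   layer = AxialBiMambaWrapper(d_model=64, use_mamba2=False, axis=1).to("cuda")
#   x = torch.randn(2, 16, 256, 64, device="cuda")   # (batch, rows, columns, d_model)
#   y = layer(x)                                      # same shape; each of the 2 * 16 rows is mixed as its own sequence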
|
|
|
|
|
class CaduceusEmbeddings(nn.Module): |
|
def __init__( |
|
self, |
|
config: CaduceusConfig, |
|
device=None, |
|
dtype=None, |
|
): |
|
super().__init__() |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
if config.rcps: |
|
self.word_embeddings = RCPSEmbedding( |
|
config.vocab_size, |
|
config.d_model, |
|
config.complement_map, |
|
**factory_kwargs, |
|
) |
|
else: |
|
self.word_embeddings = nn.Embedding( |
|
config.vocab_size, config.d_model, **factory_kwargs |
|
) |
|
|
|
def forward(self, input_ids): |
|
""" |
|
input_ids: (batch, seqlen) |
|
""" |
|
return self.word_embeddings(input_ids) |
|
|
|
|
|
class CaduceusMixerModel(nn.Module): |
|
def __init__( |
|
self, |
|
config: CaduceusConfig, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
super().__init__() |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
|
self.fused_add_norm = config.fused_add_norm |
|
self.rcps = config.rcps |
|
self.residual_in_fp32 = config.residual_in_fp32 |
|
|
|
self.embeddings = CaduceusEmbeddings(config, **factory_kwargs) |
|
|
if config.fused_add_norm: |
|
if layer_norm_fn is None or rms_norm_fn is None: |
|
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") |
|
|
|
self.layers = nn.ModuleList( |
|
[ |
|
create_block( |
|
config.d_model, |
|
ssm_cfg=config.ssm_cfg, |
|
norm_epsilon=config.norm_epsilon, |
|
rms_norm=config.rms_norm, |
|
residual_in_fp32=config.residual_in_fp32, |
|
fused_add_norm=config.fused_add_norm, |
|
layer_idx=i, |
|
bidirectional=config.bidirectional, |
|
bidirectional_strategy=config.bidirectional_strategy, |
|
bidirectional_weight_tie=config.bidirectional_weight_tie, |
|
rcps=config.rcps, |
|
**factory_kwargs, |
|
) |
|
for i in range(config.n_layer) |
|
] |
|
) |
|
|
|
norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)( |
|
config.d_model, eps=config.norm_epsilon, **factory_kwargs |
|
) |
|
self.norm_f = ( |
|
norm_f |
|
if (config.fused_add_norm or not config.rcps) |
|
else RCPSAddNormWrapper(norm_f) |
|
) |
|
|
|
def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False): |
|
"""Mixer forward.""" |
|
all_hidden_states = [] |
|
if inputs_embeds is not None: |
|
hidden_states = inputs_embeds |
|
else: |
|
hidden_states = self.embeddings(input_ids) |
|
|
|
residual = None |
|
for layer in self.layers: |
|
if output_hidden_states: |
|
all_hidden_states.append(hidden_states) |
|
|
|
hidden_states, residual = layer( |
|
hidden_states, residual, inference_params=None |
|
) |
|
|
|
if not self.fused_add_norm: |
|
if self.rcps: |
|
|
|
hidden_states = self.norm_f( |
|
hidden_states, residual=residual, prenorm=False |
|
) |
|
else: |
|
residual = ( |
|
(hidden_states + residual) |
|
if residual is not None |
|
else hidden_states |
|
) |
|
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) |
|
else: |
|
fused_add_norm_fn = ( |
|
rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn |
|
) |
|
if self.rcps: |
|
|
|
hidden_states_fwd = fused_add_norm_fn( |
|
hidden_states[..., : hidden_states.shape[-1] // 2], |
|
self.norm_f.weight, |
|
self.norm_f.bias, |
|
eps=self.norm_f.eps, |
|
residual=residual[..., : hidden_states.shape[-1] // 2], |
|
prenorm=False, |
|
residual_in_fp32=self.residual_in_fp32, |
|
) |
|
hidden_states_rc = fused_add_norm_fn( |
|
hidden_states[..., hidden_states.shape[-1] // 2 :].flip( |
|
dims=[-2, -1] |
|
), |
|
self.norm_f.weight, |
|
self.norm_f.bias, |
|
eps=self.norm_f.eps, |
|
residual=residual[..., hidden_states.shape[-1] // 2 :].flip( |
|
dims=[-2, -1] |
|
), |
|
prenorm=False, |
|
residual_in_fp32=self.residual_in_fp32, |
|
) |
|
hidden_states = torch.cat( |
|
[hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1 |
|
) |
|
else: |
|
|
|
hidden_states = fused_add_norm_fn( |
|
hidden_states, |
|
self.norm_f.weight, |
|
self.norm_f.bias, |
|
eps=self.norm_f.eps, |
|
residual=residual, |
|
prenorm=False, |
|
residual_in_fp32=self.residual_in_fp32, |
|
) |
|
if output_hidden_states: |
|
all_hidden_states.append(hidden_states) |
|
return hidden_states, all_hidden_states |
|
|
|
|
|
class AxialCaduceusMixerModel(nn.Module): |
|
def __init__( |
|
self, |
|
config: CaduceusConfig, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
super().__init__() |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
|
self.fused_add_norm = config.fused_add_norm |
|
self.rcps = config.rcps |
|
self.residual_in_fp32 = config.residual_in_fp32 |
|
|
|
self.embeddings = CaduceusEmbeddings(config, **factory_kwargs) |
|
|
|
self.pos_embeddings = None |
|
self.add_pos = False |
|
if config.pos_embeddings == 'Linear': |
|
self.add_pos = True |
|
self.pos_embeddings = nn.Linear(in_features=1, out_features=config.d_model, **factory_kwargs) |
|
|
|
elif config.pos_embeddings == 'Sinusoidal': |
|
self.pos_embeddings = partial(sinusoidal_encoding, d_model=config.d_model, **factory_kwargs) |
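
        # 'Linear' learns a projection of the scalar position and prepends it as an extra row in forward();
        # 'Sinusoidal' adds fixed sin/cos encodings to every row instead.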
|
|
if config.fused_add_norm: |
|
if layer_norm_fn is None or rms_norm_fn is None: |
|
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") |
|
        row_first = 1 if config.row_first else 0
|
|
|
self.layers = nn.ModuleList( |
|
[ |
|
create_axial_block( |
|
d_model=config.d_model, |
|
d_intermediate=config.d_intermediate, |
|
use_mamba2=config.use_mamba2, |
|
axis=((i + row_first) % 2) + 1, |
|
ssm_cfg=config.ssm_cfg, |
|
norm_epsilon=config.norm_epsilon, |
|
rms_norm=config.rms_norm, |
|
residual_in_fp32=config.residual_in_fp32, |
|
fused_add_norm=config.fused_add_norm, |
|
layer_idx=i, |
|
bidirectional=config.bidirectional, |
|
bidirectional_strategy=config.bidirectional_strategy, |
|
bidirectional_weight_tie=config.bidirectional_weight_tie, |
|
rcps=config.rcps, |
|
**factory_kwargs, |
|
) |
|
for i in range(config.n_layer * 2) |
|
] |
|
) |
|
|
|
norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)( |
|
config.d_model, eps=config.norm_epsilon, **factory_kwargs |
|
) |
|
self.norm_f = ( |
|
norm_f |
|
if (config.fused_add_norm or not config.rcps) |
|
else RCPSAddNormWrapper(norm_f) |
|
) |
|
|
|
def forward(self, input_ids, inputs_embeds=None, input_positions=None, output_hidden_states=False): |
|
"""Mixer forward.""" |
|
all_hidden_states = [] |
|
if inputs_embeds is not None: |
|
hidden_states = inputs_embeds |
|
else: |
|
hidden_states = self.embeddings(input_ids) |
|
if self.pos_embeddings is not None: |
|
if self.add_pos: |
|
                # Project scalar positions to d_model (cast to the projection's dtype) and prepend them as an extra row.
                pos_embedding = self.pos_embeddings(input_positions[..., None].to(self.pos_embeddings.weight.dtype))
                hidden_states = torch.cat([pos_embedding[:, None, ...], hidden_states], dim=1)
|
else: |
|
p_B, p_L = input_positions.size() |
|
B, R, L, D = hidden_states.size() |
|
assert p_B == B |
|
assert p_L == L |
|
pos_embedding = self.pos_embeddings(positions=input_positions)[:,None, ...] |
|
hidden_states += pos_embedding |
|
|
|
|
|
|
|
residual = None |
|
for layer in self.layers: |
|
if output_hidden_states: |
|
all_hidden_states.append(hidden_states) |
|
|
|
hidden_states, residual = layer( |
|
hidden_states, residual, inference_params=None |
|
) |
|
|
|
if not self.fused_add_norm: |
|
if self.rcps: |
|
|
|
hidden_states = self.norm_f( |
|
hidden_states, residual=residual, prenorm=False |
|
) |
|
else: |
|
residual = ( |
|
(hidden_states + residual) |
|
if residual is not None |
|
else hidden_states |
|
) |
|
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) |
|
else: |
|
fused_add_norm_fn = ( |
|
rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn |
|
) |
|
if self.rcps: |
|
|
|
hidden_states_fwd = fused_add_norm_fn( |
|
hidden_states[..., : hidden_states.shape[-1] // 2], |
|
self.norm_f.weight, |
|
self.norm_f.bias, |
|
eps=self.norm_f.eps, |
|
residual=residual[..., : hidden_states.shape[-1] // 2], |
|
prenorm=False, |
|
residual_in_fp32=self.residual_in_fp32, |
|
) |
|
hidden_states_rc = fused_add_norm_fn( |
|
hidden_states[..., hidden_states.shape[-1] // 2 :].flip( |
|
dims=[-2, -1] |
|
), |
|
self.norm_f.weight, |
|
self.norm_f.bias, |
|
eps=self.norm_f.eps, |
|
residual=residual[..., hidden_states.shape[-1] // 2 :].flip( |
|
dims=[-2, -1] |
|
), |
|
prenorm=False, |
|
residual_in_fp32=self.residual_in_fp32, |
|
) |
|
hidden_states = torch.cat( |
|
[hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1 |
|
) |
|
else: |
|
|
|
hidden_states = fused_add_norm_fn( |
|
hidden_states, |
|
self.norm_f.weight, |
|
self.norm_f.bias, |
|
eps=self.norm_f.eps, |
|
residual=residual, |
|
prenorm=False, |
|
residual_in_fp32=self.residual_in_fp32, |
|
) |
|
if output_hidden_states: |
|
all_hidden_states.append(hidden_states) |
|
if self.pos_embeddings is not None and self.add_pos: |
|
            # Drop the prepended position row so the output matches the input row count.
hidden_states = hidden_states[:,1:,...] |
|
return hidden_states, all_hidden_states |
|
|
|
|
|
class MixedAxialCaduceusMixerModel(nn.Module): |
|
""" |
|
    A model that alternates row-attention blocks with column-wise axial Caduceus (Mamba) blocks.
|
""" |
|
|
|
def __init__( |
|
self, |
|
config: MixedCaduceusConfig, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
super().__init__() |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
|
|
self.fused_add_norm = config.fused_add_norm |
|
self.rcps = config.rcps |
|
self.residual_in_fp32 = config.residual_in_fp32 |
|
|
|
self.embeddings = CaduceusEmbeddings(config, **factory_kwargs) |
|
|
if config.fused_add_norm: |
|
if layer_norm_fn is None or rms_norm_fn is None: |
|
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") |
|
|
|
layers = [] |
|
for i in range(config.n_layer * 2): |
|
            # Alternate block types: axis == 1 -> row attention, axis == 2 -> column-wise axial Mamba.
            axis = ((i + 1) % 2) + 1
|
block = None |
|
if axis == 1: |
|
block = create_attention_block( |
|
d_model=config.attn_d_model, |
|
n_heads=config.attn_n_heads, |
|
attention_dropout=config.attn_attn_dropout, |
|
block_dropout=config.attn_block_dropout, |
|
layer_idx=i, |
|
**factory_kwargs, |
|
) |
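                # NOTE: create_attention_block currently raises NotImplementedError, so this
                # mixed model cannot be instantiated until that factory is implemented.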
|
elif axis == 2: |
|
block = create_axial_block( |
|
d_model=config.d_model, |
|
d_intermediate=config.d_intermediate, |
|
use_mamba2=config.use_mamba2, |
|
axis=axis, |
|
ssm_cfg=config.ssm_cfg, |
|
norm_epsilon=config.norm_epsilon, |
|
rms_norm=config.rms_norm, |
|
residual_in_fp32=config.residual_in_fp32, |
|
fused_add_norm=config.fused_add_norm, |
|
layer_idx=i, |
|
bidirectional=config.bidirectional, |
|
bidirectional_strategy=config.bidirectional_strategy, |
|
bidirectional_weight_tie=config.bidirectional_weight_tie, |
|
rcps=config.rcps, |
|
**factory_kwargs, |
|
) |
|
layers.append(block) |
|
|
|
self.layers = nn.ModuleList(layers) |
|
|
|
norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)( |
|
config.d_model, eps=config.norm_epsilon, **factory_kwargs |
|
) |
|
self.norm_f = ( |
|
norm_f |
|
if (config.fused_add_norm or not config.rcps) |
|
else RCPSAddNormWrapper(norm_f) |
|
) |
|
|
|
def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False): |
|
"""Mixer forward.""" |
|
all_hidden_states = [] |
|
if inputs_embeds is not None: |
|
hidden_states = inputs_embeds |
|
else: |
|
hidden_states = self.embeddings(input_ids) |
|
|
|
residual = None |
|
for layer in self.layers: |
|
if output_hidden_states: |
|
all_hidden_states.append(hidden_states) |
|
|
|
hidden_states, residual = layer( |
|
hidden_states, residual, inference_params=None |
|
) |
|
|
|
if not self.fused_add_norm: |
|
if self.rcps: |
|
|
|
hidden_states = self.norm_f( |
|
hidden_states, residual=residual, prenorm=False |
|
) |
|
else: |
|
residual = ( |
|
(hidden_states + residual) |
|
if residual is not None |
|
else hidden_states |
|
) |
|
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) |
|
else: |
|
fused_add_norm_fn = ( |
|
rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn |
|
) |
|
if self.rcps: |
|
|
|
hidden_states_fwd = fused_add_norm_fn( |
|
hidden_states[..., : hidden_states.shape[-1] // 2], |
|
self.norm_f.weight, |
|
self.norm_f.bias, |
|
eps=self.norm_f.eps, |
|
residual=residual[..., : hidden_states.shape[-1] // 2], |
|
prenorm=False, |
|
residual_in_fp32=self.residual_in_fp32, |
|
) |
|
hidden_states_rc = fused_add_norm_fn( |
|
hidden_states[..., hidden_states.shape[-1] // 2 :].flip( |
|
dims=[-2, -1] |
|
), |
|
self.norm_f.weight, |
|
self.norm_f.bias, |
|
eps=self.norm_f.eps, |
|
residual=residual[..., hidden_states.shape[-1] // 2 :].flip( |
|
dims=[-2, -1] |
|
), |
|
prenorm=False, |
|
residual_in_fp32=self.residual_in_fp32, |
|
) |
|
hidden_states = torch.cat( |
|
[hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1 |
|
) |
|
else: |
|
|
|
hidden_states = fused_add_norm_fn( |
|
hidden_states, |
|
self.norm_f.weight, |
|
self.norm_f.bias, |
|
eps=self.norm_f.eps, |
|
residual=residual, |
|
prenorm=False, |
|
residual_in_fp32=self.residual_in_fp32, |
|
) |
|
if output_hidden_states: |
|
all_hidden_states.append(hidden_states) |
|
return hidden_states, all_hidden_states |
|
|
|
|
|
def cross_entropy(logits, y, ignore_index=-100): |
|
"""Cross entropy loss.""" |
|
logits = logits.view(-1, logits.shape[-1]) |
|
y = y.view(-1) |
|
return F.cross_entropy(logits, y, ignore_index=ignore_index) |
|
|
|
|
|
def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100): |
|
"""Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome).""" |
|
logits = logits.view(-1, logits.shape[-1]) |
|
y = y.view(-1) |
|
ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none") |
|
loss_weights = loss_weights.view(-1) |
|
loss_weights[y == ignore_index] = 0.0 |
|
|
|
return (ce * (loss_weights / loss_weights.sum())).sum() |
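

# Illustrative example (hypothetical values): the last position is ignored and the second
# token carries twice the weight of the first, with weights renormalized to sum to 1.
#
#   logits = torch.randn(1, 3, 5)
#   labels = torch.tensor([[0, 2, -100]])
#   weights = torch.tensor([[1.0, 2.0, 1.0]])
#   loss = weighted_cross_entropy(logits, labels, weights)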
|
|
|
|
|
class CaduceusPreTrainedModel(PreTrainedModel): |
|
"""PreTrainedModel wrapper for Caduceus backbone.""" |
|
|
|
config_class = CaduceusConfig |
|
base_model_prefix = "caduceus" |
|
supports_gradient_checkpointing = False |
|
_no_split_modules = ["BiMambaWrapper"] |
|
|
|
def _init_weights( |
|
self, |
|
module, |
|
initializer_range=0.02, |
|
**kwargs, |
|
): |
|
"""Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py""" |
|
|
|
n_layer = self.config.n_layer |
|
initialized_cfg = ( |
|
self.config.initializer_cfg |
|
if self.config.initializer_cfg is not None |
|
else {} |
|
) |
|
rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True) |
|
initializer_range = initialized_cfg.get("initializer_range", initializer_range) |
|
n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1) |
|
|
|
if isinstance(module, nn.Linear): |
|
if module.bias is not None: |
|
if not getattr(module.bias, "_no_reinit", False): |
|
nn.init.zeros_(module.bias) |
|
elif isinstance(module, nn.Embedding): |
|
nn.init.normal_(module.weight, std=initializer_range) |
|
|
|
if rescale_prenorm_residual: |
|
            # Reinitialize selected weights following the GPT-2 scheme: residual-path output
            # projections are rescaled by 1/sqrt(n_residuals_per_layer * n_layer) so that
            # activations do not grow with depth.
for name, p in module.named_parameters(): |
|
if name in ["out_proj.weight", "fc2.weight"]: |
|
                # Re-initialize p here (this hook may run more than once) before applying the depth scaling.
nn.init.kaiming_uniform_(p, a=math.sqrt(5)) |
|
with torch.no_grad(): |
|
p /= math.sqrt(n_residuals_per_layer * n_layer) |
|
|
|
class AxialCaduceusPreTrainedModel(PreTrainedModel): |
|
"""PreTrainedModel wrapper for Caduceus backbone.""" |
|
|
|
config_class = AxialCaduceusConfig |
|
base_model_prefix = "axial_caduceus" |
|
supports_gradient_checkpointing = False |
|
_no_split_modules = ["BiMambaWrapper"] |
|
|
|
def _init_weights( |
|
self, |
|
module, |
|
initializer_range=0.02, |
|
**kwargs, |
|
): |
|
"""Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py""" |
|
|
|
n_layer = self.config.n_layer |
|
initialized_cfg = ( |
|
self.config.initializer_cfg |
|
if self.config.initializer_cfg is not None |
|
else {} |
|
) |
|
rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True) |
|
initializer_range = initialized_cfg.get("initializer_range", initializer_range) |
|
n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1) |
|
|
|
if isinstance(module, nn.Linear): |
|
if module.bias is not None: |
|
if not getattr(module.bias, "_no_reinit", False): |
|
nn.init.zeros_(module.bias) |
|
elif isinstance(module, nn.Embedding): |
|
nn.init.normal_(module.weight, std=initializer_range) |
|
|
|
if rescale_prenorm_residual: |
|
            # Reinitialize selected weights following the GPT-2 scheme: residual-path output
            # projections are rescaled by 1/sqrt(n_residuals_per_layer * n_layer) so that
            # activations do not grow with depth.
for name, p in module.named_parameters(): |
|
if name in ["out_proj.weight", "fc2.weight"]: |
|
                # Re-initialize p here (this hook may run more than once) before applying the depth scaling.
nn.init.kaiming_uniform_(p, a=math.sqrt(5)) |
|
with torch.no_grad(): |
|
p /= math.sqrt(n_residuals_per_layer * n_layer) |
|
|
|
|
|
|
|
class Caduceus(CaduceusPreTrainedModel): |
|
"""Caduceus model that can be instantiated using HF patterns.""" |
|
|
|
def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): |
|
super().__init__(config) |
|
|
|
if config.rcps: |
|
assert ( |
|
config.complement_map is not None |
|
), "Complement map must be provided for RCPS." |
|
|
|
        # Pad the vocabulary to a multiple of `pad_vocab_size_multiple` and extend the complement map to cover any new tokens.
if config.vocab_size % config.pad_vocab_size_multiple != 0: |
|
config.vocab_size += config.pad_vocab_size_multiple - ( |
|
config.vocab_size % config.pad_vocab_size_multiple |
|
) |
|
if config.complement_map is not None and config.vocab_size > len( |
|
config.complement_map |
|
): |
|
for i in range(len(config.complement_map), config.vocab_size): |
|
config.complement_map[i] = i |
|
|
|
self.config = config |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs) |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]: |
|
"""HF-compatible forward method.""" |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
hidden_states, all_hidden_states = self.backbone( |
|
input_ids, |
|
inputs_embeds=inputs_embeds, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
if return_dict: |
|
return BaseModelOutputWithNoAttention( |
|
last_hidden_state=hidden_states, |
|
hidden_states=all_hidden_states if output_hidden_states else None, |
|
) |
|
elif output_hidden_states: |
|
return hidden_states, all_hidden_states |
|
else: |
|
return hidden_states |
|
|
|
|
|
class AxialCaduceus(AxialCaduceusPreTrainedModel): |
|
"""Caduceus model that can be instantiated using HF patterns.""" |
|
|
|
def __init__(self, config: AxialCaduceusConfig, device=None, dtype=None, **kwargs): |
|
super().__init__(config) |
|
|
|
if config.rcps: |
|
assert ( |
|
config.complement_map is not None |
|
), "Complement map must be provided for RCPS." |
|
|
|
|
|
if config.vocab_size % config.pad_vocab_size_multiple != 0: |
|
config.vocab_size += config.pad_vocab_size_multiple - ( |
|
config.vocab_size % config.pad_vocab_size_multiple |
|
) |
|
if config.complement_map is not None and config.vocab_size > len( |
|
config.complement_map |
|
): |
|
for i in range(len(config.complement_map), config.vocab_size): |
|
config.complement_map[i] = i |
|
|
|
self.config = config |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
self.backbone = AxialCaduceusMixerModel(config, **factory_kwargs, **kwargs) |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
input_positions: Optional[torch.LongTensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]: |
|
"""HF-compatible forward method.""" |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
hidden_states, all_hidden_states = self.backbone( |
|
input_ids, |
|
inputs_embeds=inputs_embeds, |
|
input_positions=input_positions, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
if return_dict: |
|
return BaseModelOutputWithNoAttention( |
|
last_hidden_state=hidden_states, |
|
hidden_states=all_hidden_states if output_hidden_states else None, |
|
) |
|
elif output_hidden_states: |
|
return hidden_states, all_hidden_states |
|
else: |
|
return hidden_states |
|
|
|
|
|
class MixedAxialCaduceus(CaduceusPreTrainedModel): |
|
"""Mixed Caduceus/Attention model that can be instantiated using HF patterns.""" |
|
|
|
def __init__(self, config: MixedCaduceusConfig, device=None, dtype=None, **kwargs): |
|
super().__init__(config) |
|
|
|
if config.rcps: |
|
assert ( |
|
config.complement_map is not None |
|
), "Complement map must be provided for RCPS." |
|
|
|
|
|
if config.vocab_size % config.pad_vocab_size_multiple != 0: |
|
config.vocab_size += config.pad_vocab_size_multiple - ( |
|
config.vocab_size % config.pad_vocab_size_multiple |
|
) |
|
if config.complement_map is not None and config.vocab_size > len( |
|
config.complement_map |
|
): |
|
for i in range(len(config.complement_map), config.vocab_size): |
|
config.complement_map[i] = i |
|
|
|
self.config = config |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
self.backbone = MixedAxialCaduceusMixerModel(config, **factory_kwargs, **kwargs) |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]: |
|
"""HF-compatible forward method.""" |
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
hidden_states, all_hidden_states = self.backbone( |
|
input_ids, |
|
inputs_embeds=inputs_embeds, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
if return_dict: |
|
return BaseModelOutputWithNoAttention( |
|
last_hidden_state=hidden_states, |
|
hidden_states=all_hidden_states if output_hidden_states else None, |
|
) |
|
elif output_hidden_states: |
|
return hidden_states, all_hidden_states |
|
else: |
|
return hidden_states |
|
|
|
|
|
class CaduceusForMaskedLM(CaduceusPreTrainedModel): |
|
"""HF-compatible Caduceus model for masked language modeling.""" |
|
|
|
def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): |
|
super().__init__(config, **kwargs) |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
self.caduceus = Caduceus(config, **factory_kwargs, **kwargs) |
|
if config.rcps: |
|
self.lm_head = RCPSLMHead( |
|
complement_map=self.config.complement_map, |
|
vocab_size=self.config.vocab_size, |
|
true_dim=config.d_model, |
|
dtype=dtype, |
|
) |
|
else: |
|
self.lm_head = nn.Linear( |
|
config.d_model, |
|
self.config.vocab_size, |
|
bias=False, |
|
**factory_kwargs, |
|
) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.caduceus.backbone.embeddings.word_embeddings |
|
|
|
def set_input_embeddings(self, value): |
|
if self.config.rcps: |
|
raise NotImplementedError( |
|
"Setting input embeddings for RCPS LM is not supported." |
|
) |
|
self.caduceus.backbone.embeddings.word_embeddings = value |
|
|
|
def get_output_embeddings(self): |
|
return self.lm_head |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
"""Overrides output embeddings.""" |
|
if self.config.rcps: |
|
raise NotImplementedError( |
|
"Setting output embeddings for RCPS LM is not supported." |
|
) |
|
self.lm_head = new_embeddings |
|
|
|
def tie_weights(self): |
|
"""Tie weights, accounting for RCPS.""" |
|
if self.config.rcps: |
|
self.lm_head.set_weight(self.get_input_embeddings().weight) |
|
else: |
|
super().tie_weights() |
|
|
|
def get_decoder(self): |
|
"""Get decoder (backbone) for the model.""" |
|
return self.caduceus |
|
|
|
def set_decoder(self, decoder): |
|
"""Set decoder (backbone) for the model.""" |
|
self.caduceus = decoder |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
loss_weights: Optional[torch.FloatTensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, MaskedLMOutput]: |
|
"""HF-compatible forward method.""" |
|
|
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
|
|
outputs = self.caduceus( |
|
input_ids=input_ids, |
|
inputs_embeds=inputs_embeds, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
logits = self.lm_head(hidden_states) |
|
logits = logits.float() |
|
|
|
loss = None |
|
if labels is not None: |
|
if loss_weights is not None: |
|
loss = weighted_cross_entropy( |
|
logits, labels, loss_weights, ignore_index=self.config.pad_token_id |
|
) |
|
else: |
|
loss = cross_entropy( |
|
logits, labels, ignore_index=self.config.pad_token_id |
|
) |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return (loss,) + output if loss is not None else output |
|
|
|
return MaskedLMOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
) |
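

# Minimal usage sketch (comment only, and hedged: it assumes CaduceusConfig exposes d_model,
# n_layer and vocab_size as constructor arguments, as suggested by their use above, and that
# a CUDA device is available for the Mamba kernels):
#
#   config = CaduceusConfig(d_model=128, n_layer=2, vocab_size=12)
#   model = CaduceusForMaskedLM(config).to("cuda")
#   input_ids = torch.randint(0, config.vocab_size, (2, 1024), device="cuda")
#   out = model(input_ids)   # out.logits: (2, 1024, config.vocab_size); pass labels= for the MLM loss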
|
|
|
|
|
class AxialCaduceusForMaskedLM(AxialCaduceusPreTrainedModel): |
|
"""HF-compatible Caduceus model for masked language modeling.""" |
|
|
|
    def __init__(self, config: AxialCaduceusConfig, device=None, dtype=None, **kwargs):
|
super().__init__(config, **kwargs) |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
self.caduceus = AxialCaduceus(config, **factory_kwargs, **kwargs) |
|
if config.rcps: |
|
self.lm_head = RCPSLMHead( |
|
complement_map=self.config.complement_map, |
|
vocab_size=self.config.vocab_size, |
|
true_dim=config.d_model, |
|
dtype=dtype, |
|
) |
|
else: |
|
self.lm_head = nn.Linear( |
|
config.d_model, |
|
self.config.vocab_size, |
|
bias=False, |
|
**factory_kwargs, |
|
) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.caduceus.backbone.embeddings.word_embeddings |
|
|
|
def set_input_embeddings(self, value): |
|
if self.config.rcps: |
|
raise NotImplementedError( |
|
"Setting input embeddings for RCPS LM is not supported." |
|
) |
|
self.caduceus.backbone.embeddings.word_embeddings = value |
|
|
|
def get_output_embeddings(self): |
|
return self.lm_head |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
"""Overrides output embeddings.""" |
|
if self.config.rcps: |
|
raise NotImplementedError( |
|
"Setting output embeddings for RCPS LM is not supported." |
|
) |
|
self.lm_head = new_embeddings |
|
|
|
def tie_weights(self): |
|
"""Tie weights, accounting for RCPS.""" |
|
if self.config.rcps: |
|
self.lm_head.set_weight(self.get_input_embeddings().weight) |
|
else: |
|
super().tie_weights() |
|
|
|
def get_decoder(self): |
|
"""Get decoder (backbone) for the model.""" |
|
return self.caduceus |
|
|
|
def set_decoder(self, decoder): |
|
"""Set decoder (backbone) for the model.""" |
|
self.caduceus = decoder |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
input_positions: Optional[torch.LongTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
loss_weights: Optional[torch.FloatTensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, MaskedLMOutput]: |
|
"""HF-compatible forward method.""" |
|
|
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
|
|
outputs = self.caduceus( |
|
input_ids=input_ids, |
|
inputs_embeds=inputs_embeds, |
|
input_positions=input_positions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
logits = self.lm_head(hidden_states) |
|
logits = logits.float() |
|
|
|
loss = None |
|
if labels is not None: |
|
if loss_weights is not None: |
|
loss = weighted_cross_entropy( |
|
logits, labels, loss_weights, ignore_index=self.config.pad_token_id |
|
) |
|
else: |
|
loss = cross_entropy( |
|
logits, labels, ignore_index=self.config.pad_token_id |
|
) |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return (loss,) + output if loss is not None else output |
|
|
|
return MaskedLMOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
) |
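

# Minimal usage sketch (comment only, hedged: AxialCaduceusConfig field names are inferred from
# their use in this file, and a CUDA device is assumed). Inputs are 2D grids of tokens, e.g. an
# alignment of shape (batch, rows, columns):
#
#   config = AxialCaduceusConfig(d_model=128, n_layer=2, vocab_size=12)
#   model = AxialCaduceusForMaskedLM(config).to("cuda")
#   ids = torch.randint(0, config.vocab_size, (1, 8, 256), device="cuda")
#   out = model(ids)   # out.logits: (1, 8, 256, config.vocab_size)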
|
|
|
|
|
class MixedAxialCaduceusForMaskedLM(CaduceusPreTrainedModel): |
|
"""HF-compatible Caduceus model for masked language modeling.""" |
|
|
|
def __init__(self, config: MixedCaduceusConfig, device=None, dtype=None, **kwargs): |
|
super().__init__(config, **kwargs) |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
self.caduceus = MixedAxialCaduceus(config, **factory_kwargs, **kwargs) |
|
if config.rcps: |
|
self.lm_head = RCPSLMHead( |
|
complement_map=self.config.complement_map, |
|
vocab_size=self.config.vocab_size, |
|
true_dim=config.d_model, |
|
dtype=dtype, |
|
) |
|
else: |
|
self.lm_head = nn.Linear( |
|
config.d_model, |
|
self.config.vocab_size, |
|
bias=False, |
|
**factory_kwargs, |
|
) |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.caduceus.backbone.embeddings.word_embeddings |
|
|
|
def set_input_embeddings(self, value): |
|
if self.config.rcps: |
|
raise NotImplementedError( |
|
"Setting input embeddings for RCPS LM is not supported." |
|
) |
|
self.caduceus.backbone.embeddings.word_embeddings = value |
|
|
|
def get_output_embeddings(self): |
|
return self.lm_head |
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
"""Overrides output embeddings.""" |
|
if self.config.rcps: |
|
raise NotImplementedError( |
|
"Setting output embeddings for RCPS LM is not supported." |
|
) |
|
self.lm_head = new_embeddings |
|
|
|
def tie_weights(self): |
|
"""Tie weights, accounting for RCPS.""" |
|
if self.config.rcps: |
|
self.lm_head.set_weight(self.get_input_embeddings().weight) |
|
else: |
|
super().tie_weights() |
|
|
|
def get_decoder(self): |
|
"""Get decoder (backbone) for the model.""" |
|
return self.caduceus |
|
|
|
def set_decoder(self, decoder): |
|
"""Set decoder (backbone) for the model.""" |
|
self.caduceus = decoder |
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
loss_weights: Optional[torch.FloatTensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, MaskedLMOutput]: |
|
"""HF-compatible forward method.""" |
|
|
|
output_hidden_states = ( |
|
output_hidden_states |
|
if output_hidden_states is not None |
|
else self.config.output_hidden_states |
|
) |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
|
|
outputs = self.caduceus( |
|
input_ids=input_ids, |
|
inputs_embeds=inputs_embeds, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
hidden_states = outputs[0] |
|
logits = self.lm_head(hidden_states) |
|
logits = logits.float() |
|
|
|
loss = None |
|
if labels is not None: |
|
if loss_weights is not None: |
|
loss = weighted_cross_entropy( |
|
logits, labels, loss_weights, ignore_index=self.config.pad_token_id |
|
) |
|
else: |
|
loss = cross_entropy( |
|
logits, labels, ignore_index=self.config.pad_token_id |
|
) |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
return (loss,) + output if loss is not None else output |
|
|
|
return MaskedLMOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
) |
|
|
|
|
|
class CaduceusForSequenceClassification(CaduceusPreTrainedModel): |
|
def __init__( |
|
self, |
|
config: CaduceusConfig, |
|
pooling_strategy: str = "mean", |
|
conjoin_train: bool = False, |
|
conjoin_eval: bool = False, |
|
device=None, |
|
dtype=None, |
|
**kwargs, |
|
): |
|
super().__init__(config, **kwargs) |
|
if pooling_strategy not in ["mean", "max", "first", "last"]: |
|
raise NotImplementedError( |
|
f"Pooling strategy `{pooling_strategy}` not implemented." |
|
) |
|
self.pooling_strategy = pooling_strategy |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
self.num_labels = kwargs.get("num_labels", config.num_labels) |
|
self.caduceus = Caduceus(config, **factory_kwargs, **kwargs) |
|
self.score = nn.Linear(config.d_model, self.num_labels, bias=False) |
|
|
|
self.conjoin_train = conjoin_train |
|
self.conjoin_eval = conjoin_eval |
|
|
|
|
|
self.post_init() |
|
|
|
def get_input_embeddings(self): |
|
return self.caduceus.backbone.embeddings.word_embeddings |
|
|
|
def set_input_embeddings(self, value): |
|
if self.config.rcps: |
|
raise NotImplementedError( |
|
"Setting input embeddings for RCPS LM is not supported." |
|
) |
|
self.caduceus.backbone.embeddings.word_embeddings = value |
|
|
|
def pool_hidden_states(self, hidden_states, sequence_length_dim=1): |
|
"""Pools hidden states along sequence length dimension.""" |
|
        if self.pooling_strategy == "mean":
            return hidden_states.mean(dim=sequence_length_dim)
        if self.pooling_strategy == "max":
            return hidden_states.max(dim=sequence_length_dim).values
        if self.pooling_strategy == "last":
            # Move the sequence dimension to the front and take the last position.
            return hidden_states.moveaxis(sequence_length_dim, 0)[-1, ...]
        if self.pooling_strategy == "first":
            return hidden_states.moveaxis(sequence_length_dim, 0)[0, ...]
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple, SequenceClassifierOutput]: |
|
r""" |
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., |
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If |
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy). |
|
""" |
|
return_dict = ( |
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
) |
|
|
|
|
|
if self.config.rcps: |
|
transformer_outputs = self.caduceus( |
|
input_ids, |
|
inputs_embeds=inputs_embeds, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
hidden_states = torch.stack( |
|
[ |
|
transformer_outputs[0][..., : self.config.d_model], |
|
torch.flip( |
|
transformer_outputs[0][..., self.config.d_model :], dims=[1, 2] |
|
), |
|
], |
|
dim=-1, |
|
) |
|
elif self.conjoin_train or ( |
|
self.conjoin_eval and not self.training |
|
): |
|
assert input_ids is not None, "`input_ids` must be provided for conjoining." |
|
assert ( |
|
input_ids.ndim == 3 |
|
), "`input_ids` must be 3D tensor: channels corresponds to forward and rc strands." |
|
transformer_outputs = self.caduceus( |
|
input_ids[..., 0], |
|
inputs_embeds=None, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
transformer_outputs_rc = self.caduceus( |
|
input_ids[..., 1], |
|
inputs_embeds=None, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
hidden_states = torch.stack( |
|
[transformer_outputs[0], transformer_outputs_rc[0]], dim=-1 |
|
) |
|
else: |
|
transformer_outputs = self.caduceus( |
|
input_ids, |
|
inputs_embeds=None, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
hidden_states = transformer_outputs[0] |
|
|
|
|
|
pooled_hidden_states = self.pool_hidden_states(hidden_states) |
|
|
|
        if hidden_states.ndim == 4:
            # Forward and reverse-complement channels are stacked in the last dim; average their logits.
|
logits_fwd = self.score(pooled_hidden_states[..., 0]) |
|
logits_rc = self.score(pooled_hidden_states[..., 1]) |
|
logits = (logits_fwd + logits_rc) / 2 |
|
else: |
|
logits = self.score(pooled_hidden_states) |
|
|
|
loss = None |
|
if labels is not None: |
|
labels = labels.to(logits.device) |
|
if self.config.problem_type is None: |
|
if self.num_labels == 1: |
|
self.config.problem_type = "regression" |
|
elif self.num_labels > 1 and ( |
|
labels.dtype == torch.long or labels.dtype == torch.int |
|
): |
|
self.config.problem_type = "single_label_classification" |
|
else: |
|
self.config.problem_type = "multi_label_classification" |
|
|
|
if self.config.problem_type == "regression": |
|
if self.num_labels == 1: |
|
loss = F.mse_loss(logits.squeeze(), labels.squeeze()) |
|
else: |
|
loss = F.mse_loss(logits, labels) |
|
elif self.config.problem_type == "single_label_classification": |
|
loss = F.cross_entropy( |
|
logits.view(-1, self.num_labels), labels.view(-1) |
|
) |
|
elif self.config.problem_type == "multi_label_classification": |
|
loss = F.binary_cross_entropy_with_logits(logits, labels) |
|
if not return_dict: |
|
output = (logits,) + transformer_outputs[1:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return SequenceClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=transformer_outputs.hidden_states, |
|
) |
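

# Minimal usage sketch for sequence classification (comment only, hedged: assumes the config
# fields noted above plus HF's default num_labels, and a CUDA device):
#
#   config = CaduceusConfig(d_model=128, n_layer=2, vocab_size=12)
#   clf = CaduceusForSequenceClassification(config, pooling_strategy="mean").to("cuda")
#   ids = torch.randint(0, config.vocab_size, (4, 512), device="cuda")
#   labels = torch.tensor([0, 1, 0, 1], device="cuda")
#   out = clf(ids, labels=labels)   # out.logits: (4, config.num_labels); out.loss: scalar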
|
|