"""Caduceus model for Hugging Face.
"""
import math
from functools import partial
from typing import Optional, Tuple, Union
import torch
#from mamba_ssm.modules.mamba_simple import Mamba, Block
#from mamba_ssm.modules import Block
from mamba_ssm import Mamba, Mamba2
from mamba_ssm.modules.block import Block
from mamba_ssm.modules.mlp import GatedMLP
from torch import nn
from torch.nn import functional as F
from torch.nn.parallel import parallel_apply
from transformers import PreTrainedModel
from transformers.modeling_outputs import (
BaseModelOutputWithNoAttention,
MaskedLMOutput,
SequenceClassifierOutput,
)
try:
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
from .configuration_caduceus import CaduceusConfig, MixedCaduceusConfig, AxialCaduceusConfig
from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock
#from .esm_repo.esm.axial_attention import RowSelfAttention
#from .esm_repo.esm.modules import NormalizedResidualBlock
def sinusoidal_encoding(positions: torch.Tensor, d_model: int, device=None, dtype=None):
"""
from https://github.com/wzlxjtu/PositionalEncoding2D
:param d_model: dimension of the model (d model)
:param positions: Tensor of the input positions [B, L]
:return: length*d_model position matrix
"""
factory_kwargs = {"device": device, "dtype": dtype}
if d_model % 2 != 0:
raise ValueError("Cannot use sin/cos positional encoding with "
"odd dim (got dim={:d})".format(d_model))
B, L = positions.size()
    pe = torch.zeros(B, L, d_model, **factory_kwargs)  # [B, L, D]
# position = torch.arange(0, length).unsqueeze(1) #[L, 1]
position = positions.unsqueeze(-1) # [B,L,1]
div_term = torch.exp((torch.arange(0, d_model, 2, device=position.device, dtype=torch.float) *
-(math.log(10000.0) / d_model)))
pe[:, :, 0::2] = torch.sin(position.float() * div_term)
pe[:, :, 1::2] = torch.cos(position.float() * div_term)
pe = pe.to(**factory_kwargs)
return pe
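# Illustrative sketch (not used by the model): minimal check of sinusoidal_encoding's shapes.
def _example_sinusoidal_encoding():
    """Encode positions 0..3 for a batch of two sequences; returns a [2, 4, 8] tensor."""
    positions = torch.arange(4).repeat(2, 1)        # [B=2, L=4]
    pe = sinusoidal_encoding(positions, d_model=8)  # [2, 4, 8]
    # Even feature indices hold the sin components, odd indices the cos components.
    assert pe.shape == (2, 4, 8)
    return pe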
def create_block(
d_model,
ssm_cfg=None,
norm_epsilon=1e-5,
rms_norm=False,
residual_in_fp32=False,
fused_add_norm=False,
layer_idx=None,
bidirectional=True,
bidirectional_strategy="add",
bidirectional_weight_tie=True,
rcps=False,
device=None,
dtype=None,
):
"""Create Caduceus block.
Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
"""
if ssm_cfg is None:
ssm_cfg = {}
factory_kwargs = {"device": device, "dtype": dtype}
bidirectional_kwargs = {
"bidirectional": bidirectional,
"bidirectional_strategy": bidirectional_strategy,
"bidirectional_weight_tie": bidirectional_weight_tie,
}
mixer_cls = partial(
BiMambaWrapper,
layer_idx=layer_idx,
**ssm_cfg,
**bidirectional_kwargs,
**factory_kwargs,
)
norm_cls = partial(
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
)
block_cls = RCPSMambaBlock if rcps else Block
    d_intermediate = 0  # No intermediate MLP in this block, so mlp_cls resolves to nn.Identity below
if d_intermediate == 0:
mlp_cls = nn.Identity
else:
mlp_cls = partial(
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
)
block = block_cls(
dim=d_model,
mixer_cls=mixer_cls,
mlp_cls=mlp_cls,
norm_cls=norm_cls,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
)
block.layer_idx = layer_idx
return block
def create_axial_block(
d_model,
d_intermediate,
use_mamba2,
axis,
ssm_cfg=None,
norm_epsilon=1e-5,
rms_norm=False,
residual_in_fp32=False,
fused_add_norm=False,
layer_idx=None,
bidirectional=True,
bidirectional_strategy="add",
bidirectional_weight_tie=True,
rcps=False,
device=None,
dtype=None,
):
"""Create an axial Caduceus block composed of two AxialCaduceus blocks, one for row and one for columns.
Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py
"""
if ssm_cfg is None:
ssm_cfg = {}
factory_kwargs = {"device": device, "dtype": dtype}
bidirectional_kwargs = {
"bidirectional": bidirectional,
"bidirectional_strategy": bidirectional_strategy,
"bidirectional_weight_tie": bidirectional_weight_tie,
}
#mixer_cls = partial(
# Mamba2 if ssm_layer == "Mamba2" else Mamba,
# layer_idx=layer_idx,
# **ssm_cfg,
# **factory_kwargs
#)
mixer_cls = partial(
AxialBiMambaWrapper,
use_mamba2=use_mamba2,
axis=axis,
layer_idx=layer_idx,
**ssm_cfg,
**bidirectional_kwargs,
**factory_kwargs,
)
norm_cls = partial(
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
)
block_cls = RCPSMambaBlock if rcps else Block
if d_intermediate == 0:
mlp_cls = nn.Identity
else:
mlp_cls = partial(
GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
)
block = block_cls(
dim=d_model,
mixer_cls=mixer_cls,
mlp_cls=mlp_cls,
norm_cls=norm_cls,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32,
)
block.layer_idx = layer_idx
return block
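# Usage sketch (illustrative hyper-parameters): AxialCaduceusMixerModel below alternates the two
# axes across layer indices, roughly:
#   col_block = create_axial_block(d_model=128, d_intermediate=0, use_mamba2=False, axis=2, layer_idx=0)
#   row_block = create_axial_block(d_model=128, d_intermediate=0, use_mamba2=False, axis=1, layer_idx=1)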
def create_attention_block(
d_model: int,
n_heads: int,
attention_dropout: float,
block_dropout: float,
layer_idx=None,
device=None,
dtype=None,
):
"""Create an RowAttention block from MSATransformer."""
raise NotImplementedError()
# factory_kwargs = {"device": device, "dtype": dtype}
# layer_cls = RowSelfAttention(
# embed_dim=d_model, num_heads=n_heads, dropout=attention_dropout
# )
# block = NormalizedResidualBlock(
# layer=layer_cls, embedding_dim=d_model, dropout=block_dropout
# ) # Wraps attention with residual connection, layer norm, and drop out. NOTE: No mixer in this block
# block = block.to(device)
# block.layer_idx = layer_idx
# return block
class BiMambaWrapper(nn.Module):
"""Thin wrapper around Mamba to support bi-directionality."""
def __init__(
self,
d_model: int,
bidirectional: bool = True,
bidirectional_strategy: Optional[str] = "add",
bidirectional_weight_tie: bool = True,
**mamba_kwargs,
):
super().__init__()
if bidirectional and bidirectional_strategy is None:
bidirectional_strategy = "add" # Default strategy: `add`
if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
raise NotImplementedError(
f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!"
)
self.bidirectional = bidirectional
self.bidirectional_strategy = bidirectional_strategy
self.mamba_fwd = Mamba(d_model=d_model, **mamba_kwargs)
if bidirectional:
self.mamba_rev = Mamba(d_model=d_model, **mamba_kwargs)
if (
bidirectional_weight_tie
): # Tie in and out projections (where most of param count lies)
self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
else:
self.mamba_rev = None
def forward(self, hidden_states, inference_params=None):
"""Bidirectional-enabled forward pass
hidden_states: (B, L, D)
Returns: same shape as hidden_states
"""
out = self.mamba_fwd(hidden_states, inference_params=inference_params)
if self.bidirectional:
out_rev = self.mamba_rev(
hidden_states.flip(
dims=(1,)
), # Flip along the sequence length dimension
inference_params=inference_params,
).flip(dims=(1,)) # Flip back for combining with forward hidden states
if self.bidirectional_strategy == "add":
out = out + out_rev
elif self.bidirectional_strategy == "ew_multiply":
out = out * out_rev
else:
raise NotImplementedError(
f"`{self.bidirectional_strategy}` for bi-directionality not implemented!"
)
return out
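# Usage sketch for BiMambaWrapper (assumes a CUDA device, which the Mamba kernels require):
#   bimamba = BiMambaWrapper(d_model=128).to("cuda")
#   x = torch.randn(2, 1024, 128, device="cuda")   # (B, L, D)
#   y = bimamba(x)                                  # same shape; forward and reversed scans combined via "add"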
class AxialBiMambaWrapper(nn.Module):
"""Thin wrapper around BiMamba to support running and aggregating over rows.
axis=1 for RowMamba, axis=2 for column Mamba
"""
def __init__(
self,
d_model: int,
use_mamba2: bool,
bidirectional: bool = True,
bidirectional_strategy: Optional[str] = "add",
bidirectional_weight_tie: bool = True,
axis: int = 1,
**mamba_kwargs,
):
super().__init__()
if bidirectional and bidirectional_strategy is None:
bidirectional_strategy = "add" # Default strategy: `add`
if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]:
raise NotImplementedError(
f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!"
)
self.bidirectional = bidirectional
self.bidirectional_strategy = bidirectional_strategy
self.mamba_fwd = Mamba2(d_model=d_model, **mamba_kwargs) if use_mamba2 else Mamba(d_model=d_model, **mamba_kwargs)
self.axis = axis
if bidirectional:
self.mamba_rev = Mamba2(d_model=d_model, **mamba_kwargs) if use_mamba2 else Mamba(d_model=d_model, **mamba_kwargs)
if (
bidirectional_weight_tie
): # Tie in and out projections (where most of param count lies)
self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight
self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias
self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight
self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias
else:
self.mamba_rev = None
def forward(self, hidden_states, inference_params=None):
"""Bidirectional-enabled forward pass
hidden_states: (B, R, C, D)
Returns: same shape as hidden_states
"""
def apply_mamba(x):
out = self.mamba_fwd(x, inference_params=inference_params)
if self.bidirectional:
out_rev = self.mamba_rev(
x.flip(
dims=(1,)
), # Flip along the sequence length dimension
inference_params=inference_params,
).flip(dims=(1,)) # Flip back for combining with forward hidden states
if self.bidirectional_strategy == "add":
out = out + out_rev
elif self.bidirectional_strategy == "ew_multiply":
out = out * out_rev
else:
raise NotImplementedError(
f"`{self.bidirectional_strategy}` for bi-directionality not implemented!"
)
return out
        batch, rows, columns, hidden_dim = hidden_states.size()
        if self.axis == 1:  # row Mamba: scan along columns within each row
            hidden_states = hidden_states.permute(1, 0, 2, 3)  # [R, B, C, D]
            axis_len = rows
        elif self.axis == 2:  # column Mamba: scan along rows within each column
            hidden_states = hidden_states.permute(2, 0, 1, 3)  # [C, B, R, D]
            axis_len = columns
        else:
            raise ValueError(f"Invalid axis {self.axis}; expected 1 (rows) or 2 (columns).")
        ## parallel (alternative):
        # outs = parallel_apply([apply_mamba for _ in range(axis_len)], hidden_states.unbind(0))
        ## reshape: fold the scanned axis into the batch dimension so one Mamba call handles every slice
        outs = apply_mamba(hidden_states.reshape(axis_len * batch, -1, hidden_dim))
        out = outs.reshape(axis_len, batch, -1, hidden_dim)
        ## for-loop (alternative):
        # outs = []
        # for axis_idx in range(axis_len):
        #     outs.append(apply_mamba(hidden_states[axis_idx, ...]))
        # out = torch.stack(outs, dim=0)
        if self.axis == 1:  # row Mamba: [R, B, C, D] -> [B, R, C, D]
            out = out.permute(1, 0, 2, 3)
        elif self.axis == 2:  # column Mamba: [C, B, R, D] -> [B, R, C, D]
            out = out.permute(1, 2, 0, 3)
return out
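# Shape walk-through (sketch): for axis=2 (column Mamba) an input of shape [B, R, C, D] is permuted
# to [C, B, R, D] and reshaped to [C * B, R, D], so a single Mamba call scans the row dimension of
# every column; the result is reshaped and permuted back to [B, R, C, D]. Example (assumes CUDA):
#   col_mixer = AxialBiMambaWrapper(d_model=64, use_mamba2=False, axis=2).to("cuda")
#   y = col_mixer(torch.randn(2, 8, 16, 64, device="cuda"))   # y.shape == (2, 8, 16, 64)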
class CaduceusEmbeddings(nn.Module):
def __init__(
self,
config: CaduceusConfig,
device=None,
dtype=None,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
if config.rcps:
self.word_embeddings = RCPSEmbedding(
config.vocab_size,
config.d_model,
config.complement_map,
**factory_kwargs,
)
else:
self.word_embeddings = nn.Embedding(
config.vocab_size, config.d_model, **factory_kwargs
)
def forward(self, input_ids):
"""
input_ids: (batch, seqlen)
"""
return self.word_embeddings(input_ids)
class CaduceusMixerModel(nn.Module):
def __init__(
self,
config: CaduceusConfig,
device=None,
dtype=None,
) -> None:
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.fused_add_norm = config.fused_add_norm
self.rcps = config.rcps
self.residual_in_fp32 = config.residual_in_fp32
self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
# Mamba changes the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Add, we do:
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
# the main branch (output of MLP / Mixer). The model definition is unchanged.
# This is for performance reason: we can fuse add + layer_norm.
if config.fused_add_norm:
if layer_norm_fn is None or rms_norm_fn is None:
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
self.layers = nn.ModuleList(
[
create_block(
config.d_model,
ssm_cfg=config.ssm_cfg,
norm_epsilon=config.norm_epsilon,
rms_norm=config.rms_norm,
residual_in_fp32=config.residual_in_fp32,
fused_add_norm=config.fused_add_norm,
layer_idx=i,
bidirectional=config.bidirectional,
bidirectional_strategy=config.bidirectional_strategy,
bidirectional_weight_tie=config.bidirectional_weight_tie,
rcps=config.rcps,
**factory_kwargs,
)
for i in range(config.n_layer)
]
)
norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
config.d_model, eps=config.norm_epsilon, **factory_kwargs
)
self.norm_f = (
norm_f
if (config.fused_add_norm or not config.rcps)
else RCPSAddNormWrapper(norm_f)
)
def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
"""Mixer forward."""
all_hidden_states = []
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids)
residual = None
for layer in self.layers:
if output_hidden_states:
all_hidden_states.append(hidden_states)
# TODO: Add support for gradient checkpointing
hidden_states, residual = layer(
hidden_states, residual, inference_params=None
)
if not self.fused_add_norm:
if self.rcps:
# Set prenorm=False here since we don't need the residual
hidden_states = self.norm_f(
hidden_states, residual=residual, prenorm=False
)
else:
residual = (
(hidden_states + residual)
if residual is not None
else hidden_states
)
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
fused_add_norm_fn = (
rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
)
if self.rcps:
# Set prenorm=False here since we don't need the residual
hidden_states_fwd = fused_add_norm_fn(
hidden_states[..., : hidden_states.shape[-1] // 2],
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual[..., : hidden_states.shape[-1] // 2],
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states_rc = fused_add_norm_fn(
hidden_states[..., hidden_states.shape[-1] // 2 :].flip(
dims=[-2, -1]
),
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual[..., hidden_states.shape[-1] // 2 :].flip(
dims=[-2, -1]
),
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states = torch.cat(
[hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1
)
else:
# Set prenorm=False here since we don't need the residual
hidden_states = fused_add_norm_fn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
if output_hidden_states:
all_hidden_states.append(hidden_states)
return hidden_states, all_hidden_states
class AxialCaduceusMixerModel(nn.Module):
def __init__(
self,
config: CaduceusConfig,
device=None,
dtype=None,
) -> None:
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.fused_add_norm = config.fused_add_norm
self.rcps = config.rcps
self.residual_in_fp32 = config.residual_in_fp32
self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
        self.pos_embeddings = None
        self.add_pos = False
        if config.pos_embeddings == 'Linear':
            # Learned positional embedding, prepended to the MSA as an extra row (removed again at the end of forward)
            self.add_pos = True
            self.pos_embeddings = nn.Linear(in_features=1, out_features=config.d_model, **factory_kwargs)
        elif config.pos_embeddings == 'Sinusoidal':
            # Fixed sinusoidal encoding, broadcast-added to every row of the MSA
            self.pos_embeddings = partial(sinusoidal_encoding, d_model=config.d_model, **factory_kwargs)
# Mamba changes the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Add, we do:
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
# the main branch (output of MLP / Mixer). The model definition is unchanged.
# This is for performance reason: we can fuse add + layer_norm.
if config.fused_add_norm:
if layer_norm_fn is None or rms_norm_fn is None:
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
        row_first = 1 if config.row_first else 0  # default (0): column SSM first
self.layers = nn.ModuleList(
[
create_axial_block(
d_model=config.d_model,
d_intermediate=config.d_intermediate,
use_mamba2=config.use_mamba2,
axis=((i + row_first) % 2) + 1, # (i%2) + 1 for columns first
ssm_cfg=config.ssm_cfg,
norm_epsilon=config.norm_epsilon,
rms_norm=config.rms_norm,
residual_in_fp32=config.residual_in_fp32,
fused_add_norm=config.fused_add_norm,
layer_idx=i,
bidirectional=config.bidirectional,
bidirectional_strategy=config.bidirectional_strategy,
bidirectional_weight_tie=config.bidirectional_weight_tie,
rcps=config.rcps,
**factory_kwargs,
)
                for i in range(config.n_layer * 2)  # two axial blocks (one per axis) per configured layer
]
)
norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
config.d_model, eps=config.norm_epsilon, **factory_kwargs
)
self.norm_f = (
norm_f
if (config.fused_add_norm or not config.rcps)
else RCPSAddNormWrapper(norm_f)
)
def forward(self, input_ids, inputs_embeds=None, input_positions=None, output_hidden_states=False):
"""Mixer forward."""
all_hidden_states = []
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids)
if self.pos_embeddings is not None:
if self.add_pos:
pos_embedding = self.pos_embeddings(input_positions[...,None]) #[B, L, D]
hidden_states = torch.cat([pos_embedding[:,None, ...], hidden_states], dim=1)
else:
p_B, p_L = input_positions.size()
B, R, L, D = hidden_states.size()
assert p_B == B
assert p_L == L
pos_embedding = self.pos_embeddings(positions=input_positions)[:,None, ...] # [B, 1, L, D]
hidden_states += pos_embedding
residual = None
for layer in self.layers:
if output_hidden_states:
all_hidden_states.append(hidden_states)
# TODO: Add support for gradient checkpointing
hidden_states, residual = layer(
hidden_states, residual, inference_params=None
)
if not self.fused_add_norm:
if self.rcps:
# Set prenorm=False here since we don't need the residual
hidden_states = self.norm_f(
hidden_states, residual=residual, prenorm=False
)
else:
residual = (
(hidden_states + residual)
if residual is not None
else hidden_states
)
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
fused_add_norm_fn = (
rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
)
if self.rcps:
# Set prenorm=False here since we don't need the residual
hidden_states_fwd = fused_add_norm_fn(
hidden_states[..., : hidden_states.shape[-1] // 2],
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual[..., : hidden_states.shape[-1] // 2],
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states_rc = fused_add_norm_fn(
hidden_states[..., hidden_states.shape[-1] // 2 :].flip(
dims=[-2, -1]
),
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual[..., hidden_states.shape[-1] // 2 :].flip(
dims=[-2, -1]
),
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states = torch.cat(
[hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1
)
else:
# Set prenorm=False here since we don't need the residual
hidden_states = fused_add_norm_fn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
if output_hidden_states:
all_hidden_states.append(hidden_states)
        if self.pos_embeddings is not None and self.add_pos:
            # Remove the prepended positional-embedding row from the returned MSA
            hidden_states = hidden_states[:, 1:, ...]
return hidden_states, all_hidden_states
class MixedAxialCaduceusMixerModel(nn.Module):
"""
A model that swtiches between Caducues and Standard attention mechanisms
"""
def __init__(
self,
config: MixedCaduceusConfig,
device=None,
dtype=None,
) -> None:
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.fused_add_norm = config.fused_add_norm
self.rcps = config.rcps
self.residual_in_fp32 = config.residual_in_fp32
self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
# Mamba changes the order of residual and layer norm:
# Instead of LN -> Attn / MLP -> Add, we do:
# Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
# the main branch (output of MLP / Mixer). The model definition is unchanged.
# This is for performance reason: we can fuse add + layer_norm.
if config.fused_add_norm:
if layer_norm_fn is None or rms_norm_fn is None:
raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
layers = []
for i in range(config.n_layer * 2):
axis = ((i + 1) % 2) + 1 # 1 for rows, 2 for columns, columns first.
block = None
if axis == 1:
block = create_attention_block(
d_model=config.attn_d_model,
n_heads=config.attn_n_heads,
attention_dropout=config.attn_attn_dropout,
block_dropout=config.attn_block_dropout,
layer_idx=i,
**factory_kwargs,
)
elif axis == 2:
block = create_axial_block(
d_model=config.d_model,
d_intermediate=config.d_intermediate,
use_mamba2=config.use_mamba2,
axis=axis, # always columns
ssm_cfg=config.ssm_cfg,
norm_epsilon=config.norm_epsilon,
rms_norm=config.rms_norm,
residual_in_fp32=config.residual_in_fp32,
fused_add_norm=config.fused_add_norm,
layer_idx=i,
bidirectional=config.bidirectional,
bidirectional_strategy=config.bidirectional_strategy,
bidirectional_weight_tie=config.bidirectional_weight_tie,
rcps=config.rcps,
**factory_kwargs,
)
layers.append(block)
self.layers = nn.ModuleList(layers)
norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)(
config.d_model, eps=config.norm_epsilon, **factory_kwargs
)
self.norm_f = (
norm_f
if (config.fused_add_norm or not config.rcps)
else RCPSAddNormWrapper(norm_f)
)
def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False):
"""Mixer forward."""
all_hidden_states = []
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids)
residual = None
for layer in self.layers:
if output_hidden_states:
all_hidden_states.append(hidden_states)
# TODO: Add support for gradient checkpointing
hidden_states, residual = layer(
hidden_states, residual, inference_params=None
)
if not self.fused_add_norm:
if self.rcps:
# Set prenorm=False here since we don't need the residual
hidden_states = self.norm_f(
hidden_states, residual=residual, prenorm=False
)
else:
residual = (
(hidden_states + residual)
if residual is not None
else hidden_states
)
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
fused_add_norm_fn = (
rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
)
if self.rcps:
# Set prenorm=False here since we don't need the residual
hidden_states_fwd = fused_add_norm_fn(
hidden_states[..., : hidden_states.shape[-1] // 2],
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual[..., : hidden_states.shape[-1] // 2],
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states_rc = fused_add_norm_fn(
hidden_states[..., hidden_states.shape[-1] // 2 :].flip(
dims=[-2, -1]
),
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual[..., hidden_states.shape[-1] // 2 :].flip(
dims=[-2, -1]
),
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
hidden_states = torch.cat(
[hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1
)
else:
# Set prenorm=False here since we don't need the residual
hidden_states = fused_add_norm_fn(
hidden_states,
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
prenorm=False,
residual_in_fp32=self.residual_in_fp32,
)
if output_hidden_states:
all_hidden_states.append(hidden_states)
return hidden_states, all_hidden_states
def cross_entropy(logits, y, ignore_index=-100):
"""Cross entropy loss."""
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
return F.cross_entropy(logits, y, ignore_index=ignore_index)
def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100):
"""Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome)."""
logits = logits.view(-1, logits.shape[-1])
y = y.view(-1)
ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none")
loss_weights = loss_weights.view(-1)
loss_weights[y == ignore_index] = 0.0
# TODO: Follows GPN implementation, but should we remove weight normalization?
return (ce * (loss_weights / loss_weights.sum())).sum()
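# Illustrative sketch (not used in training): weighted_cross_entropy zeroes the weight of positions
# whose label equals ignore_index and normalizes the remaining weights to sum to one.
def _example_weighted_cross_entropy():
    logits = torch.randn(1, 3, 5)              # (batch, seq_len, vocab)
    labels = torch.tensor([[1, 2, -100]])      # last position is ignored
    weights = torch.tensor([[1.0, 0.5, 1.0]])  # e.g., discount repeated base pairs
    return weighted_cross_entropy(logits, labels, weights)  # scalar loss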
class CaduceusPreTrainedModel(PreTrainedModel):
"""PreTrainedModel wrapper for Caduceus backbone."""
config_class = CaduceusConfig
base_model_prefix = "caduceus"
supports_gradient_checkpointing = False
_no_split_modules = ["BiMambaWrapper"]
def _init_weights(
self,
module,
initializer_range=0.02, # Now only used for embedding layer.
**kwargs,
):
"""Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""
n_layer = self.config.n_layer
initialized_cfg = (
self.config.initializer_cfg
if self.config.initializer_cfg is not None
else {}
)
rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
initializer_range = initialized_cfg.get("initializer_range", initializer_range)
n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth.
# > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
# residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(n_residuals_per_layer * n_layer)
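# Worked example of the rescaling above (sketch): with n_layer = 12 and n_residuals_per_layer = 1,
# out_proj.weight / fc2.weight are re-initialized with Kaiming-uniform and then divided by
# sqrt(1 * 12) ≈ 3.46.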
class AxialCaduceusPreTrainedModel(PreTrainedModel):
"""PreTrainedModel wrapper for Caduceus backbone."""
config_class = AxialCaduceusConfig
base_model_prefix = "axial_caduceus"
supports_gradient_checkpointing = False
_no_split_modules = ["BiMambaWrapper"]
def _init_weights(
self,
module,
initializer_range=0.02, # Now only used for embedding layer.
**kwargs,
):
"""Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py"""
n_layer = self.config.n_layer
initialized_cfg = (
self.config.initializer_cfg
if self.config.initializer_cfg is not None
else {}
)
rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True)
initializer_range = initialized_cfg.get("initializer_range", initializer_range)
n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1)
if isinstance(module, nn.Linear):
if module.bias is not None:
if not getattr(module.bias, "_no_reinit", False):
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, std=initializer_range)
if rescale_prenorm_residual:
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
# > A modified initialization which accounts for the accumulation on the residual path with model depth.
# > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of
# residual layers.
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
#
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
for name, p in module.named_parameters():
if name in ["out_proj.weight", "fc2.weight"]:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
# Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
# We need to reinit p since this code could be called multiple times
# Having just p *= scale would repeatedly scale it down
nn.init.kaiming_uniform_(p, a=math.sqrt(5))
with torch.no_grad():
p /= math.sqrt(n_residuals_per_layer * n_layer)
class Caduceus(CaduceusPreTrainedModel):
"""Caduceus model that can be instantiated using HF patterns."""
def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
super().__init__(config)
if config.rcps:
assert (
config.complement_map is not None
), "Complement map must be provided for RCPS."
# Adjust vocab size and complement maps if vocab padding is set.
if config.vocab_size % config.pad_vocab_size_multiple != 0:
config.vocab_size += config.pad_vocab_size_multiple - (
config.vocab_size % config.pad_vocab_size_multiple
)
if config.complement_map is not None and config.vocab_size > len(
config.complement_map
):
for i in range(len(config.complement_map), config.vocab_size):
config.complement_map[i] = i
self.config = config
factory_kwargs = {"device": device, "dtype": dtype}
self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs)
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
"""HF-compatible forward method."""
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
hidden_states, all_hidden_states = self.backbone(
input_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
)
if return_dict:
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states if output_hidden_states else None,
)
elif output_hidden_states:
return hidden_states, all_hidden_states
else:
return hidden_states
class AxialCaduceus(AxialCaduceusPreTrainedModel):
"""Caduceus model that can be instantiated using HF patterns."""
def __init__(self, config: AxialCaduceusConfig, device=None, dtype=None, **kwargs):
super().__init__(config)
if config.rcps:
assert (
config.complement_map is not None
), "Complement map must be provided for RCPS."
# Adjust vocab size and complement maps if vocab padding is set.
if config.vocab_size % config.pad_vocab_size_multiple != 0:
config.vocab_size += config.pad_vocab_size_multiple - (
config.vocab_size % config.pad_vocab_size_multiple
)
if config.complement_map is not None and config.vocab_size > len(
config.complement_map
):
for i in range(len(config.complement_map), config.vocab_size):
config.complement_map[i] = i
self.config = config
factory_kwargs = {"device": device, "dtype": dtype}
self.backbone = AxialCaduceusMixerModel(config, **factory_kwargs, **kwargs)
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
input_positions: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
"""HF-compatible forward method."""
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
hidden_states, all_hidden_states = self.backbone(
input_ids,
inputs_embeds=inputs_embeds,
input_positions=input_positions,
output_hidden_states=output_hidden_states,
)
if return_dict:
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states if output_hidden_states else None,
)
elif output_hidden_states:
return hidden_states, all_hidden_states
else:
return hidden_states
class MixedAxialCaduceus(CaduceusPreTrainedModel):
"""Mixed Caduceus/Attention model that can be instantiated using HF patterns."""
def __init__(self, config: MixedCaduceusConfig, device=None, dtype=None, **kwargs):
super().__init__(config)
if config.rcps:
assert (
config.complement_map is not None
), "Complement map must be provided for RCPS."
# Adjust vocab size and complement maps if vocab padding is set.
if config.vocab_size % config.pad_vocab_size_multiple != 0:
config.vocab_size += config.pad_vocab_size_multiple - (
config.vocab_size % config.pad_vocab_size_multiple
)
if config.complement_map is not None and config.vocab_size > len(
config.complement_map
):
for i in range(len(config.complement_map), config.vocab_size):
config.complement_map[i] = i
self.config = config
factory_kwargs = {"device": device, "dtype": dtype}
self.backbone = MixedAxialCaduceusMixerModel(config, **factory_kwargs, **kwargs)
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]:
"""HF-compatible forward method."""
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
hidden_states, all_hidden_states = self.backbone(
input_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
)
if return_dict:
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states if output_hidden_states else None,
)
elif output_hidden_states:
return hidden_states, all_hidden_states
else:
return hidden_states
class CaduceusForMaskedLM(CaduceusPreTrainedModel):
"""HF-compatible Caduceus model for masked language modeling."""
def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs):
super().__init__(config, **kwargs)
factory_kwargs = {"device": device, "dtype": dtype}
self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
if config.rcps:
self.lm_head = RCPSLMHead(
complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
true_dim=config.d_model,
dtype=dtype,
)
else:
self.lm_head = nn.Linear(
config.d_model,
self.config.vocab_size, # Use caduceus config as it might have been updated
bias=False,
**factory_kwargs,
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.caduceus.backbone.embeddings.word_embeddings
def set_input_embeddings(self, value):
if self.config.rcps:
raise NotImplementedError(
"Setting input embeddings for RCPS LM is not supported."
)
self.caduceus.backbone.embeddings.word_embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Overrides output embeddings."""
if self.config.rcps:
raise NotImplementedError(
"Setting output embeddings for RCPS LM is not supported."
)
self.lm_head = new_embeddings
def tie_weights(self):
"""Tie weights, accounting for RCPS."""
if self.config.rcps:
self.lm_head.set_weight(self.get_input_embeddings().weight)
else:
super().tie_weights()
def get_decoder(self):
"""Get decoder (backbone) for the model."""
return self.caduceus
def set_decoder(self, decoder):
"""Set decoder (backbone) for the model."""
self.caduceus = decoder
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
loss_weights: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
"""HF-compatible forward method."""
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
        # Get the backbone outputs; outputs[0] is the last hidden state
outputs = self.caduceus(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
if loss_weights is not None:
loss = weighted_cross_entropy(
logits, labels, loss_weights, ignore_index=self.config.pad_token_id
)
else:
loss = cross_entropy(
logits, labels, ignore_index=self.config.pad_token_id
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
class AxialCaduceusForMaskedLM(AxialCaduceusPreTrainedModel):
"""HF-compatible Caduceus model for masked language modeling."""
    def __init__(self, config: AxialCaduceusConfig, device=None, dtype=None, **kwargs):
super().__init__(config, **kwargs)
factory_kwargs = {"device": device, "dtype": dtype}
self.caduceus = AxialCaduceus(config, **factory_kwargs, **kwargs)
if config.rcps:
self.lm_head = RCPSLMHead(
complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
true_dim=config.d_model,
dtype=dtype,
)
else:
self.lm_head = nn.Linear(
config.d_model,
self.config.vocab_size, # Use caduceus config as it might have been updated
bias=False,
**factory_kwargs,
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.caduceus.backbone.embeddings.word_embeddings
def set_input_embeddings(self, value):
if self.config.rcps:
raise NotImplementedError(
"Setting input embeddings for RCPS LM is not supported."
)
self.caduceus.backbone.embeddings.word_embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Overrides output embeddings."""
if self.config.rcps:
raise NotImplementedError(
"Setting output embeddings for RCPS LM is not supported."
)
self.lm_head = new_embeddings
def tie_weights(self):
"""Tie weights, accounting for RCPS."""
if self.config.rcps:
self.lm_head.set_weight(self.get_input_embeddings().weight)
else:
super().tie_weights()
def get_decoder(self):
"""Get decoder (backbone) for the model."""
return self.caduceus
def set_decoder(self, decoder):
"""Set decoder (backbone) for the model."""
self.caduceus = decoder
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
input_positions: Optional[torch.LongTensor] = None,
labels: Optional[torch.LongTensor] = None,
loss_weights: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
"""HF-compatible forward method."""
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
        # Get the backbone outputs; outputs[0] is the last hidden state
outputs = self.caduceus(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
input_positions=input_positions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
if loss_weights is not None:
loss = weighted_cross_entropy(
logits, labels, loss_weights, ignore_index=self.config.pad_token_id
)
else:
loss = cross_entropy(
logits, labels, ignore_index=self.config.pad_token_id
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
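# Usage sketch (hypothetical checkpoint name; illustrative only, requires trust_remote_code=True so
# this module is fetched from the Hub):
#   from transformers import AutoModelForMaskedLM
#   model = AutoModelForMaskedLM.from_pretrained("your-org/axial-caduceus-checkpoint", trust_remote_code=True)
#   out = model(input_ids=token_ids)   # token_ids: (batch, rows, columns) LongTensor, e.g. an MSA
#   out.logits                         # (batch, rows, columns, vocab_size)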
class MixedAxialCaduceusForMaskedLM(CaduceusPreTrainedModel):
"""HF-compatible Caduceus model for masked language modeling."""
def __init__(self, config: MixedCaduceusConfig, device=None, dtype=None, **kwargs):
super().__init__(config, **kwargs)
factory_kwargs = {"device": device, "dtype": dtype}
self.caduceus = MixedAxialCaduceus(config, **factory_kwargs, **kwargs)
if config.rcps:
self.lm_head = RCPSLMHead(
complement_map=self.config.complement_map, # Use caduceus config as it might have been updated
vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated
true_dim=config.d_model,
dtype=dtype,
)
else:
self.lm_head = nn.Linear(
config.d_model,
self.config.vocab_size, # Use caduceus config as it might have been updated
bias=False,
**factory_kwargs,
)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.caduceus.backbone.embeddings.word_embeddings
def set_input_embeddings(self, value):
if self.config.rcps:
raise NotImplementedError(
"Setting input embeddings for RCPS LM is not supported."
)
self.caduceus.backbone.embeddings.word_embeddings = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
"""Overrides output embeddings."""
if self.config.rcps:
raise NotImplementedError(
"Setting output embeddings for RCPS LM is not supported."
)
self.lm_head = new_embeddings
def tie_weights(self):
"""Tie weights, accounting for RCPS."""
if self.config.rcps:
self.lm_head.set_weight(self.get_input_embeddings().weight)
else:
super().tie_weights()
def get_decoder(self):
"""Get decoder (backbone) for the model."""
return self.caduceus
def set_decoder(self, decoder):
"""Set decoder (backbone) for the model."""
self.caduceus = decoder
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
loss_weights: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, MaskedLMOutput]:
"""HF-compatible forward method."""
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
        # Get the backbone outputs; outputs[0] is the last hidden state
outputs = self.caduceus(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
if loss_weights is not None:
loss = weighted_cross_entropy(
logits, labels, loss_weights, ignore_index=self.config.pad_token_id
)
else:
loss = cross_entropy(
logits, labels, ignore_index=self.config.pad_token_id
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
)
class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
def __init__(
self,
config: CaduceusConfig,
pooling_strategy: str = "mean",
conjoin_train: bool = False,
conjoin_eval: bool = False,
device=None,
dtype=None,
**kwargs,
):
super().__init__(config, **kwargs)
if pooling_strategy not in ["mean", "max", "first", "last"]:
raise NotImplementedError(
f"Pooling strategy `{pooling_strategy}` not implemented."
)
self.pooling_strategy = pooling_strategy
factory_kwargs = {"device": device, "dtype": dtype}
self.num_labels = kwargs.get("num_labels", config.num_labels)
self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
self.conjoin_train = conjoin_train
self.conjoin_eval = conjoin_eval
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.caduceus.backbone.embeddings.word_embeddings
def set_input_embeddings(self, value):
if self.config.rcps:
raise NotImplementedError(
"Setting input embeddings for RCPS LM is not supported."
)
self.caduceus.backbone.embeddings.word_embeddings = value
def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
"""Pools hidden states along sequence length dimension."""
if (
self.pooling_strategy == "mean"
): # Mean pooling along sequence length dimension
return hidden_states.mean(dim=sequence_length_dim)
if (
self.pooling_strategy == "max"
): # Max pooling along sequence length dimension
return hidden_states.max(dim=sequence_length_dim).values
if (
self.pooling_strategy == "last"
): # Use embedding of last token in the sequence
            return hidden_states.moveaxis(sequence_length_dim, 0)[-1, ...]
        if (
            self.pooling_strategy == "first"
        ):  # Use embedding of first token in the sequence
            return hidden_states.moveaxis(sequence_length_dim, 0)[0, ...]
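    # Sketch: with hidden_states of shape (B, L, D) and the default sequence_length_dim=1,
    # "mean" and "max" reduce over L and return (B, D); "first"/"last" select a single token's
    # embedding along that dimension, also (B, D). For conjoined inputs the shape is
    # (B, L, D, 2) and pooling keeps the trailing strand dimension.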
def forward(
self,
input_ids: torch.LongTensor = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Get hidden representations from the backbone
if self.config.rcps: # Hidden states have 2 * d_model channels for RCPS
transformer_outputs = self.caduceus(
input_ids,
inputs_embeds=inputs_embeds,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = torch.stack(
[
transformer_outputs[0][..., : self.config.d_model],
torch.flip(
transformer_outputs[0][..., self.config.d_model :], dims=[1, 2]
),
],
dim=-1,
)
elif self.conjoin_train or (
self.conjoin_eval and not self.training
): # For conjoining / post-hoc conjoining
assert input_ids is not None, "`input_ids` must be provided for conjoining."
assert (
input_ids.ndim == 3
), "`input_ids` must be 3D tensor: channels corresponds to forward and rc strands."
transformer_outputs = self.caduceus(
input_ids[..., 0],
inputs_embeds=None,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
transformer_outputs_rc = self.caduceus(
input_ids[..., 1],
inputs_embeds=None,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# Stack along channel dimension (dim=-1)
hidden_states = torch.stack(
[transformer_outputs[0], transformer_outputs_rc[0]], dim=-1
)
else:
transformer_outputs = self.caduceus(
input_ids,
inputs_embeds=None,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
# Pool and get logits
pooled_hidden_states = self.pool_hidden_states(hidden_states)
# Potentially run `score` twice (with parameters shared) for conjoining
if (
hidden_states.ndim == 4
): # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
logits_fwd = self.score(pooled_hidden_states[..., 0])
logits_rc = self.score(pooled_hidden_states[..., 1])
logits = (logits_fwd + logits_rc) / 2
else:
logits = self.score(pooled_hidden_states)
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
if self.num_labels == 1:
loss = F.mse_loss(logits.squeeze(), labels.squeeze())
else:
loss = F.mse_loss(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss = F.cross_entropy(
logits.view(-1, self.num_labels), labels.view(-1)
)
elif self.config.problem_type == "multi_label_classification":
loss = F.binary_cross_entropy_with_logits(logits, labels)
if not return_dict:
output = (logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=transformer_outputs.hidden_states,
)
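# Usage sketch for CaduceusForSequenceClassification (hypothetical checkpoint name; illustrative only):
# with conjoin_train / conjoin_eval, input_ids is expected to have shape (batch, seq_len, 2), holding
# the forward strand and its reverse complement in the last dimension; the two strands' logits are
# averaged after shared scoring:
#   clf = CaduceusForSequenceClassification.from_pretrained("your-org/caduceus-checkpoint", conjoin_eval=True)
#   logits = clf(paired_input_ids).logits   # (batch, num_labels)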