"""Caduceus model for Hugging Face. """ import math from functools import partial from typing import Optional, Tuple, Union import torch #from mamba_ssm.modules.mamba_simple import Mamba, Block #from mamba_ssm.modules import Block from mamba_ssm import Mamba, Mamba2 from mamba_ssm.modules.block import Block from mamba_ssm.modules.mlp import GatedMLP from torch import nn from torch.nn import functional as F from torch.nn.parallel import parallel_apply from transformers import PreTrainedModel from transformers.modeling_outputs import ( BaseModelOutputWithNoAttention, MaskedLMOutput, SequenceClassifierOutput, ) try: from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn except ImportError: RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None from .configuration_caduceus import CaduceusConfig, MixedCaduceusConfig, AxialCaduceusConfig from .modeling_rcps import RCPSAddNormWrapper, RCPSEmbedding, RCPSLMHead, RCPSMambaBlock #from .esm_repo.esm.axial_attention import RowSelfAttention #from .esm_repo.esm.modules import NormalizedResidualBlock def sinusoidal_encoding(positions: torch.Tensor, d_model: int, device=None, dtype=None): """ from https://github.com/wzlxjtu/PositionalEncoding2D :param d_model: dimension of the model (d model) :param positions: Tensor of the input positions [B, L] :return: length*d_model position matrix """ factory_kwargs = {"device": device, "dtype": dtype} if d_model % 2 != 0: raise ValueError("Cannot use sin/cos positional encoding with " "odd dim (got dim={:d})".format(d_model)) B, L = positions.size() pe = torch.zeros(B, L, d_model, **factory_kwargs) # [B, L, D} # position = torch.arange(0, length).unsqueeze(1) #[L, 1] position = positions.unsqueeze(-1) # [B,L,1] div_term = torch.exp((torch.arange(0, d_model, 2, device=position.device, dtype=torch.float) * -(math.log(10000.0) / d_model))) pe[:, :, 0::2] = torch.sin(position.float() * div_term) pe[:, :, 1::2] = torch.cos(position.float() * div_term) pe = pe.to(**factory_kwargs) return pe def create_block( d_model, ssm_cfg=None, norm_epsilon=1e-5, rms_norm=False, residual_in_fp32=False, fused_add_norm=False, layer_idx=None, bidirectional=True, bidirectional_strategy="add", bidirectional_weight_tie=True, rcps=False, device=None, dtype=None, ): """Create Caduceus block. 
Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py """ if ssm_cfg is None: ssm_cfg = {} factory_kwargs = {"device": device, "dtype": dtype} bidirectional_kwargs = { "bidirectional": bidirectional, "bidirectional_strategy": bidirectional_strategy, "bidirectional_weight_tie": bidirectional_weight_tie, } mixer_cls = partial( BiMambaWrapper, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs, ) norm_cls = partial( nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs ) block_cls = RCPSMambaBlock if rcps else Block d_intermediate=0 if d_intermediate == 0: mlp_cls = nn.Identity else: mlp_cls = partial( GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs ) block = block_cls( dim=d_model, mixer_cls=mixer_cls, mlp_cls=mlp_cls, norm_cls=norm_cls, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, ) block.layer_idx = layer_idx return block def create_axial_block( d_model, d_intermediate, use_mamba2, axis, ssm_cfg=None, norm_epsilon=1e-5, rms_norm=False, residual_in_fp32=False, fused_add_norm=False, layer_idx=None, bidirectional=True, bidirectional_strategy="add", bidirectional_weight_tie=True, rcps=False, device=None, dtype=None, ): """Create an axial Caduceus block composed of two AxialCaduceus blocks, one for row and one for columns. Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py """ if ssm_cfg is None: ssm_cfg = {} factory_kwargs = {"device": device, "dtype": dtype} bidirectional_kwargs = { "bidirectional": bidirectional, "bidirectional_strategy": bidirectional_strategy, "bidirectional_weight_tie": bidirectional_weight_tie, } #mixer_cls = partial( # Mamba2 if ssm_layer == "Mamba2" else Mamba, # layer_idx=layer_idx, # **ssm_cfg, # **factory_kwargs #) mixer_cls = partial( AxialBiMambaWrapper, use_mamba2=use_mamba2, axis=axis, layer_idx=layer_idx, **ssm_cfg, **bidirectional_kwargs, **factory_kwargs, ) norm_cls = partial( nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs ) block_cls = RCPSMambaBlock if rcps else Block if d_intermediate == 0: mlp_cls = nn.Identity else: mlp_cls = partial( GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs ) block = block_cls( dim=d_model, mixer_cls=mixer_cls, mlp_cls=mlp_cls, norm_cls=norm_cls, fused_add_norm=fused_add_norm, residual_in_fp32=residual_in_fp32, ) block.layer_idx = layer_idx return block def create_attention_block( d_model: int, n_heads: int, attention_dropout: float, block_dropout: float, layer_idx=None, device=None, dtype=None, ): """Create an RowAttention block from MSATransformer.""" raise NotImplementedError() # factory_kwargs = {"device": device, "dtype": dtype} # layer_cls = RowSelfAttention( # embed_dim=d_model, num_heads=n_heads, dropout=attention_dropout # ) # block = NormalizedResidualBlock( # layer=layer_cls, embedding_dim=d_model, dropout=block_dropout # ) # Wraps attention with residual connection, layer norm, and drop out. 
NOTE: No mixer in this block # block = block.to(device) # block.layer_idx = layer_idx # return block class BiMambaWrapper(nn.Module): """Thin wrapper around Mamba to support bi-directionality.""" def __init__( self, d_model: int, bidirectional: bool = True, bidirectional_strategy: Optional[str] = "add", bidirectional_weight_tie: bool = True, **mamba_kwargs, ): super().__init__() if bidirectional and bidirectional_strategy is None: bidirectional_strategy = "add" # Default strategy: `add` if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]: raise NotImplementedError( f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!" ) self.bidirectional = bidirectional self.bidirectional_strategy = bidirectional_strategy self.mamba_fwd = Mamba(d_model=d_model, **mamba_kwargs) if bidirectional: self.mamba_rev = Mamba(d_model=d_model, **mamba_kwargs) if ( bidirectional_weight_tie ): # Tie in and out projections (where most of param count lies) self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias else: self.mamba_rev = None def forward(self, hidden_states, inference_params=None): """Bidirectional-enabled forward pass hidden_states: (B, L, D) Returns: same shape as hidden_states """ out = self.mamba_fwd(hidden_states, inference_params=inference_params) if self.bidirectional: out_rev = self.mamba_rev( hidden_states.flip( dims=(1,) ), # Flip along the sequence length dimension inference_params=inference_params, ).flip(dims=(1,)) # Flip back for combining with forward hidden states if self.bidirectional_strategy == "add": out = out + out_rev elif self.bidirectional_strategy == "ew_multiply": out = out * out_rev else: raise NotImplementedError( f"`{self.bidirectional_strategy}` for bi-directionality not implemented!" ) return out class AxialBiMambaWrapper(nn.Module): """Thin wrapper around BiMamba to support running and aggregating over rows. axis=1 for RowMamba, axis=2 for column Mamba """ def __init__( self, d_model: int, use_mamba2: bool, bidirectional: bool = True, bidirectional_strategy: Optional[str] = "add", bidirectional_weight_tie: bool = True, axis: int = 1, **mamba_kwargs, ): super().__init__() if bidirectional and bidirectional_strategy is None: bidirectional_strategy = "add" # Default strategy: `add` if bidirectional and bidirectional_strategy not in ["add", "ew_multiply"]: raise NotImplementedError( f"`{bidirectional_strategy}` strategy for bi-directionality is not implemented!" 
) self.bidirectional = bidirectional self.bidirectional_strategy = bidirectional_strategy self.mamba_fwd = Mamba2(d_model=d_model, **mamba_kwargs) if use_mamba2 else Mamba(d_model=d_model, **mamba_kwargs) self.axis = axis if bidirectional: self.mamba_rev = Mamba2(d_model=d_model, **mamba_kwargs) if use_mamba2 else Mamba(d_model=d_model, **mamba_kwargs) if ( bidirectional_weight_tie ): # Tie in and out projections (where most of param count lies) self.mamba_rev.in_proj.weight = self.mamba_fwd.in_proj.weight self.mamba_rev.in_proj.bias = self.mamba_fwd.in_proj.bias self.mamba_rev.out_proj.weight = self.mamba_fwd.out_proj.weight self.mamba_rev.out_proj.bias = self.mamba_fwd.out_proj.bias else: self.mamba_rev = None def forward(self, hidden_states, inference_params=None): """Bidirectional-enabled forward pass hidden_states: (B, R, C, D) Returns: same shape as hidden_states """ def apply_mamba(x): out = self.mamba_fwd(x, inference_params=inference_params) if self.bidirectional: out_rev = self.mamba_rev( x.flip( dims=(1,) ), # Flip along the sequence length dimension inference_params=inference_params, ).flip(dims=(1,)) # Flip back for combining with forward hidden states if self.bidirectional_strategy == "add": out = out + out_rev elif self.bidirectional_strategy == "ew_multiply": out = out * out_rev else: raise NotImplementedError( f"`{self.bidirectional_strategy}` for bi-directionality not implemented!" ) return out batch, rows, columns, hidden_dim = hidden_states.size() if self.axis == 1: # row mamba hidden_states = hidden_states.permute(1, 0, 2, 3) axis_len = rows elif self.axis == 2: hidden_states = hidden_states.permute(2, 0, 1, 3) axis_len = columns outs = [] ## parllel #outs = parallel_apply([apply_mamba for _ in range(axis_len)], hidden_states.unbind(0)) ## reshape outs = apply_mamba(hidden_states.reshape(axis_len * batch, -1, hidden_dim)) out = outs.reshape(axis_len, batch, -1, hidden_dim) ### forlop #for axis_idx in range(axis_len): #tmp_hidden_states = hidden_states[axis_idx, ...] #out = apply_mamba(tmp_hidden_states) #outs.append(out) #out = torch.stack(outs, dim=0) if self.axis == 1: # row mamba out = out.permute(1, 0, 2, 3) elif self.axis == 2: # [C, B, R, D] out = out.permute(1, 2, 0, 3) return out class CaduceusEmbeddings(nn.Module): def __init__( self, config: CaduceusConfig, device=None, dtype=None, ): super().__init__() factory_kwargs = {"device": device, "dtype": dtype} if config.rcps: self.word_embeddings = RCPSEmbedding( config.vocab_size, config.d_model, config.complement_map, **factory_kwargs, ) else: self.word_embeddings = nn.Embedding( config.vocab_size, config.d_model, **factory_kwargs ) def forward(self, input_ids): """ input_ids: (batch, seqlen) """ return self.word_embeddings(input_ids) class CaduceusMixerModel(nn.Module): def __init__( self, config: CaduceusConfig, device=None, dtype=None, ) -> None: super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.fused_add_norm = config.fused_add_norm self.rcps = config.rcps self.residual_in_fp32 = config.residual_in_fp32 self.embeddings = CaduceusEmbeddings(config, **factory_kwargs) # Mamba changes the order of residual and layer norm: # Instead of LN -> Attn / MLP -> Add, we do: # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and # the main branch (output of MLP / Mixer). The model definition is unchanged. # This is for performance reason: we can fuse add + layer_norm. 
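        # A rough sketch of the per-layer dataflow described above (names here are
        # illustrative, not the actual Block implementation):
        #     residual = hidden_states + residual       # "Add"
        #     hidden_states = mixer(norm(residual))     # "LN -> Mixer"
        # When fused_add_norm is set, the add and the norm are executed in a single
        # Triton kernel (layer_norm_fn / rms_norm_fn) instead of two separate ops.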
if config.fused_add_norm: if layer_norm_fn is None or rms_norm_fn is None: raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") self.layers = nn.ModuleList( [ create_block( config.d_model, ssm_cfg=config.ssm_cfg, norm_epsilon=config.norm_epsilon, rms_norm=config.rms_norm, residual_in_fp32=config.residual_in_fp32, fused_add_norm=config.fused_add_norm, layer_idx=i, bidirectional=config.bidirectional, bidirectional_strategy=config.bidirectional_strategy, bidirectional_weight_tie=config.bidirectional_weight_tie, rcps=config.rcps, **factory_kwargs, ) for i in range(config.n_layer) ] ) norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)( config.d_model, eps=config.norm_epsilon, **factory_kwargs ) self.norm_f = ( norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f) ) def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False): """Mixer forward.""" all_hidden_states = [] if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embeddings(input_ids) residual = None for layer in self.layers: if output_hidden_states: all_hidden_states.append(hidden_states) # TODO: Add support for gradient checkpointing hidden_states, residual = layer( hidden_states, residual, inference_params=None ) if not self.fused_add_norm: if self.rcps: # Set prenorm=False here since we don't need the residual hidden_states = self.norm_f( hidden_states, residual=residual, prenorm=False ) else: residual = ( (hidden_states + residual) if residual is not None else hidden_states ) hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) else: fused_add_norm_fn = ( rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn ) if self.rcps: # Set prenorm=False here since we don't need the residual hidden_states_fwd = fused_add_norm_fn( hidden_states[..., : hidden_states.shape[-1] // 2], self.norm_f.weight, self.norm_f.bias, eps=self.norm_f.eps, residual=residual[..., : hidden_states.shape[-1] // 2], prenorm=False, residual_in_fp32=self.residual_in_fp32, ) hidden_states_rc = fused_add_norm_fn( hidden_states[..., hidden_states.shape[-1] // 2 :].flip( dims=[-2, -1] ), self.norm_f.weight, self.norm_f.bias, eps=self.norm_f.eps, residual=residual[..., hidden_states.shape[-1] // 2 :].flip( dims=[-2, -1] ), prenorm=False, residual_in_fp32=self.residual_in_fp32, ) hidden_states = torch.cat( [hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1 ) else: # Set prenorm=False here since we don't need the residual hidden_states = fused_add_norm_fn( hidden_states, self.norm_f.weight, self.norm_f.bias, eps=self.norm_f.eps, residual=residual, prenorm=False, residual_in_fp32=self.residual_in_fp32, ) if output_hidden_states: all_hidden_states.append(hidden_states) return hidden_states, all_hidden_states class AxialCaduceusMixerModel(nn.Module): def __init__( self, config: CaduceusConfig, device=None, dtype=None, ) -> None: super().__init__() factory_kwargs = {"device": device, "dtype": dtype} self.fused_add_norm = config.fused_add_norm self.rcps = config.rcps self.residual_in_fp32 = config.residual_in_fp32 self.embeddings = CaduceusEmbeddings(config, **factory_kwargs) self.pos_embeddings = None self.add_pos = False if config.pos_embeddings == 'Linear': self.add_pos = True self.pos_embeddings = nn.Linear(in_features=1, out_features=config.d_model, **factory_kwargs) elif config.pos_embeddings == 'Sinusoidal': self.pos_embeddings = partial(sinusoidal_encoding, d_model=config.d_model, **factory_kwargs) # Mamba changes 
the order of residual and layer norm: # Instead of LN -> Attn / MLP -> Add, we do: # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and # the main branch (output of MLP / Mixer). The model definition is unchanged. # This is for performance reason: we can fuse add + layer_norm. if config.fused_add_norm: if layer_norm_fn is None or rms_norm_fn is None: raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") row_first = 0 #assume col ssm first if config.row_first: #row first row_first = 1 self.layers = nn.ModuleList( [ create_axial_block( d_model=config.d_model, d_intermediate=config.d_intermediate, use_mamba2=config.use_mamba2, axis=((i + row_first) % 2) + 1, # (i%2) + 1 for columns first ssm_cfg=config.ssm_cfg, norm_epsilon=config.norm_epsilon, rms_norm=config.rms_norm, residual_in_fp32=config.residual_in_fp32, fused_add_norm=config.fused_add_norm, layer_idx=i, bidirectional=config.bidirectional, bidirectional_strategy=config.bidirectional_strategy, bidirectional_weight_tie=config.bidirectional_weight_tie, rcps=config.rcps, **factory_kwargs, ) for i in range(config.n_layer * 2) ] ) norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)( config.d_model, eps=config.norm_epsilon, **factory_kwargs ) self.norm_f = ( norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f) ) def forward(self, input_ids, inputs_embeds=None, input_positions=None, output_hidden_states=False): """Mixer forward.""" all_hidden_states = [] if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embeddings(input_ids) if self.pos_embeddings is not None: if self.add_pos: pos_embedding = self.pos_embeddings(input_positions[...,None]) #[B, L, D] hidden_states = torch.cat([pos_embedding[:,None, ...], hidden_states], dim=1) else: p_B, p_L = input_positions.size() B, R, L, D = hidden_states.size() assert p_B == B assert p_L == L pos_embedding = self.pos_embeddings(positions=input_positions)[:,None, ...] 
                # [B, 1, L, D]
                hidden_states += pos_embedding

        residual = None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            # TODO: Add support for gradient checkpointing
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=None
            )

        if not self.fused_add_norm:
            if self.rcps:
                # Set prenorm=False here since we don't need the residual
                hidden_states = self.norm_f(
                    hidden_states, residual=residual, prenorm=False
                )
            else:
                residual = (
                    (hidden_states + residual) if residual is not None else hidden_states
                )
                hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            fused_add_norm_fn = (
                rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            )
            if self.rcps:
                # Set prenorm=False here since we don't need the residual
                hidden_states_fwd = fused_add_norm_fn(
                    hidden_states[..., : hidden_states.shape[-1] // 2],
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual[..., : hidden_states.shape[-1] // 2],
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
                hidden_states_rc = fused_add_norm_fn(
                    hidden_states[..., hidden_states.shape[-1] // 2 :].flip(
                        dims=[-2, -1]
                    ),
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual[..., hidden_states.shape[-1] // 2 :].flip(
                        dims=[-2, -1]
                    ),
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
                hidden_states = torch.cat(
                    [hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1
                )
            else:
                # Set prenorm=False here since we don't need the residual
                hidden_states = fused_add_norm_fn(
                    hidden_states,
                    self.norm_f.weight,
                    self.norm_f.bias,
                    eps=self.norm_f.eps,
                    residual=residual,
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
        if output_hidden_states:
            all_hidden_states.append(hidden_states)
        if self.pos_embeddings is not None and self.add_pos:
            # Remove the positional embeddings from the returned MSA
            hidden_states = hidden_states[:, 1:, ...]
        return hidden_states, all_hidden_states


class MixedAxialCaduceusMixerModel(nn.Module):
    """A model that switches between Caduceus and standard attention mechanisms."""

    def __init__(
        self,
        config: MixedCaduceusConfig,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.fused_add_norm = config.fused_add_norm
        self.rcps = config.rcps
        self.residual_in_fp32 = config.residual_in_fp32
        self.embeddings = CaduceusEmbeddings(config, **factory_kwargs)
        # Mamba changes the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        if config.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
        layers = []
        for i in range(config.n_layer * 2):
            axis = ((i + 1) % 2) + 1  # 1 for rows, 2 for columns, columns first.
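            # Alternation schedule implied by ((i + 1) % 2) + 1: even i gives
            # axis == 2 (a column-wise Mamba block), odd i gives axis == 1
            # (a row-wise attention block), so columns are mixed before rows.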
block = None if axis == 1: block = create_attention_block( d_model=config.attn_d_model, n_heads=config.attn_n_heads, attention_dropout=config.attn_attn_dropout, block_dropout=config.attn_block_dropout, layer_idx=i, **factory_kwargs, ) elif axis == 2: block = create_axial_block( d_model=config.d_model, d_intermediate=config.d_intermediate, use_mamba2=config.use_mamba2, axis=axis, # always columns ssm_cfg=config.ssm_cfg, norm_epsilon=config.norm_epsilon, rms_norm=config.rms_norm, residual_in_fp32=config.residual_in_fp32, fused_add_norm=config.fused_add_norm, layer_idx=i, bidirectional=config.bidirectional, bidirectional_strategy=config.bidirectional_strategy, bidirectional_weight_tie=config.bidirectional_weight_tie, rcps=config.rcps, **factory_kwargs, ) layers.append(block) self.layers = nn.ModuleList(layers) norm_f = (nn.LayerNorm if not config.rms_norm else RMSNorm)( config.d_model, eps=config.norm_epsilon, **factory_kwargs ) self.norm_f = ( norm_f if (config.fused_add_norm or not config.rcps) else RCPSAddNormWrapper(norm_f) ) def forward(self, input_ids, inputs_embeds=None, output_hidden_states=False): """Mixer forward.""" all_hidden_states = [] if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embeddings(input_ids) residual = None for layer in self.layers: if output_hidden_states: all_hidden_states.append(hidden_states) # TODO: Add support for gradient checkpointing hidden_states, residual = layer( hidden_states, residual, inference_params=None ) if not self.fused_add_norm: if self.rcps: # Set prenorm=False here since we don't need the residual hidden_states = self.norm_f( hidden_states, residual=residual, prenorm=False ) else: residual = ( (hidden_states + residual) if residual is not None else hidden_states ) hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) else: fused_add_norm_fn = ( rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn ) if self.rcps: # Set prenorm=False here since we don't need the residual hidden_states_fwd = fused_add_norm_fn( hidden_states[..., : hidden_states.shape[-1] // 2], self.norm_f.weight, self.norm_f.bias, eps=self.norm_f.eps, residual=residual[..., : hidden_states.shape[-1] // 2], prenorm=False, residual_in_fp32=self.residual_in_fp32, ) hidden_states_rc = fused_add_norm_fn( hidden_states[..., hidden_states.shape[-1] // 2 :].flip( dims=[-2, -1] ), self.norm_f.weight, self.norm_f.bias, eps=self.norm_f.eps, residual=residual[..., hidden_states.shape[-1] // 2 :].flip( dims=[-2, -1] ), prenorm=False, residual_in_fp32=self.residual_in_fp32, ) hidden_states = torch.cat( [hidden_states_fwd, hidden_states_rc.flip(dims=[-2, -1])], dim=-1 ) else: # Set prenorm=False here since we don't need the residual hidden_states = fused_add_norm_fn( hidden_states, self.norm_f.weight, self.norm_f.bias, eps=self.norm_f.eps, residual=residual, prenorm=False, residual_in_fp32=self.residual_in_fp32, ) if output_hidden_states: all_hidden_states.append(hidden_states) return hidden_states, all_hidden_states def cross_entropy(logits, y, ignore_index=-100): """Cross entropy loss.""" logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) return F.cross_entropy(logits, y, ignore_index=ignore_index) def weighted_cross_entropy(logits, y, loss_weights, ignore_index=-100): """Weighted cross entropy loss (discounts certain tokens, e.g., repeated base pairs in genome).""" logits = logits.view(-1, logits.shape[-1]) y = y.view(-1) ce = F.cross_entropy(logits, y, ignore_index=ignore_index, reduction="none") loss_weights = 
loss_weights.view(-1) loss_weights[y == ignore_index] = 0.0 # TODO: Follows GPN implementation, but should we remove weight normalization? return (ce * (loss_weights / loss_weights.sum())).sum() class CaduceusPreTrainedModel(PreTrainedModel): """PreTrainedModel wrapper for Caduceus backbone.""" config_class = CaduceusConfig base_model_prefix = "caduceus" supports_gradient_checkpointing = False _no_split_modules = ["BiMambaWrapper"] def _init_weights( self, module, initializer_range=0.02, # Now only used for embedding layer. **kwargs, ): """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py""" n_layer = self.config.n_layer initialized_cfg = ( self.config.initializer_cfg if self.config.initializer_cfg is not None else {} ) rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True) initializer_range = initialized_cfg.get("initializer_range", initializer_range) n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1) if isinstance(module, nn.Linear): if module.bias is not None: if not getattr(module.bias, "_no_reinit", False): nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=initializer_range) if rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. # > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of # residual layers. # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py for name, p in module.named_parameters(): if name in ["out_proj.weight", "fc2.weight"]: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) # We need to reinit p since this code could be called multiple times # Having just p *= scale would repeatedly scale it down nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): p /= math.sqrt(n_residuals_per_layer * n_layer) class AxialCaduceusPreTrainedModel(PreTrainedModel): """PreTrainedModel wrapper for Caduceus backbone.""" config_class = AxialCaduceusConfig base_model_prefix = "axial_caduceus" supports_gradient_checkpointing = False _no_split_modules = ["BiMambaWrapper"] def _init_weights( self, module, initializer_range=0.02, # Now only used for embedding layer. **kwargs, ): """Adapted from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py""" n_layer = self.config.n_layer initialized_cfg = ( self.config.initializer_cfg if self.config.initializer_cfg is not None else {} ) rescale_prenorm_residual = initialized_cfg.get("rescale_prenorm_residual", True) initializer_range = initialized_cfg.get("initializer_range", initializer_range) n_residuals_per_layer = initialized_cfg.get("n_residuals_per_layer", 1) if isinstance(module, nn.Linear): if module.bias is not None: if not getattr(module.bias, "_no_reinit", False): nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=initializer_range) if rescale_prenorm_residual: # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: # > A modified initialization which accounts for the accumulation on the residual path with model depth. 
# > Scale the weights of residual layers at initialization by a factor of 1/√N where N is the # of # residual layers. # > -- GPT-2 :: https://openai.com/blog/better-language-models/ # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py for name, p in module.named_parameters(): if name in ["out_proj.weight", "fc2.weight"]: # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) # We need to reinit p since this code could be called multiple times # Having just p *= scale would repeatedly scale it down nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): p /= math.sqrt(n_residuals_per_layer * n_layer) class Caduceus(CaduceusPreTrainedModel): """Caduceus model that can be instantiated using HF patterns.""" def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): super().__init__(config) if config.rcps: assert ( config.complement_map is not None ), "Complement map must be provided for RCPS." # Adjust vocab size and complement maps if vocab padding is set. if config.vocab_size % config.pad_vocab_size_multiple != 0: config.vocab_size += config.pad_vocab_size_multiple - ( config.vocab_size % config.pad_vocab_size_multiple ) if config.complement_map is not None and config.vocab_size > len( config.complement_map ): for i in range(len(config.complement_map), config.vocab_size): config.complement_map[i] = i self.config = config factory_kwargs = {"device": device, "dtype": dtype} self.backbone = CaduceusMixerModel(config, **factory_kwargs, **kwargs) def forward( self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]: """HF-compatible forward method.""" output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) hidden_states, all_hidden_states = self.backbone( input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, ) if return_dict: return BaseModelOutputWithNoAttention( last_hidden_state=hidden_states, hidden_states=all_hidden_states if output_hidden_states else None, ) elif output_hidden_states: return hidden_states, all_hidden_states else: return hidden_states class AxialCaduceus(AxialCaduceusPreTrainedModel): """Caduceus model that can be instantiated using HF patterns.""" def __init__(self, config: AxialCaduceusConfig, device=None, dtype=None, **kwargs): super().__init__(config) if config.rcps: assert ( config.complement_map is not None ), "Complement map must be provided for RCPS." # Adjust vocab size and complement maps if vocab padding is set. 
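        # Worked example of the padding below (values illustrative): with
        # vocab_size=12 and pad_vocab_size_multiple=8, vocab_size is rounded up
        # to 16, and each newly added token id maps to itself in complement_map.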
if config.vocab_size % config.pad_vocab_size_multiple != 0: config.vocab_size += config.pad_vocab_size_multiple - ( config.vocab_size % config.pad_vocab_size_multiple ) if config.complement_map is not None and config.vocab_size > len( config.complement_map ): for i in range(len(config.complement_map), config.vocab_size): config.complement_map[i] = i self.config = config factory_kwargs = {"device": device, "dtype": dtype} self.backbone = AxialCaduceusMixerModel(config, **factory_kwargs, **kwargs) def forward( self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, input_positions: Optional[torch.LongTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]: """HF-compatible forward method.""" output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) hidden_states, all_hidden_states = self.backbone( input_ids, inputs_embeds=inputs_embeds, input_positions=input_positions, output_hidden_states=output_hidden_states, ) if return_dict: return BaseModelOutputWithNoAttention( last_hidden_state=hidden_states, hidden_states=all_hidden_states if output_hidden_states else None, ) elif output_hidden_states: return hidden_states, all_hidden_states else: return hidden_states class MixedAxialCaduceus(CaduceusPreTrainedModel): """Mixed Caduceus/Attention model that can be instantiated using HF patterns.""" def __init__(self, config: MixedCaduceusConfig, device=None, dtype=None, **kwargs): super().__init__(config) if config.rcps: assert ( config.complement_map is not None ), "Complement map must be provided for RCPS." # Adjust vocab size and complement maps if vocab padding is set. 
if config.vocab_size % config.pad_vocab_size_multiple != 0: config.vocab_size += config.pad_vocab_size_multiple - ( config.vocab_size % config.pad_vocab_size_multiple ) if config.complement_map is not None and config.vocab_size > len( config.complement_map ): for i in range(len(config.complement_map), config.vocab_size): config.complement_map[i] = i self.config = config factory_kwargs = {"device": device, "dtype": dtype} self.backbone = MixedAxialCaduceusMixerModel(config, **factory_kwargs, **kwargs) def forward( self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple, BaseModelOutputWithNoAttention]: """HF-compatible forward method.""" output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) hidden_states, all_hidden_states = self.backbone( input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, ) if return_dict: return BaseModelOutputWithNoAttention( last_hidden_state=hidden_states, hidden_states=all_hidden_states if output_hidden_states else None, ) elif output_hidden_states: return hidden_states, all_hidden_states else: return hidden_states class CaduceusForMaskedLM(CaduceusPreTrainedModel): """HF-compatible Caduceus model for masked language modeling.""" def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): super().__init__(config, **kwargs) factory_kwargs = {"device": device, "dtype": dtype} self.caduceus = Caduceus(config, **factory_kwargs, **kwargs) if config.rcps: self.lm_head = RCPSLMHead( complement_map=self.config.complement_map, # Use caduceus config as it might have been updated vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated true_dim=config.d_model, dtype=dtype, ) else: self.lm_head = nn.Linear( config.d_model, self.config.vocab_size, # Use caduceus config as it might have been updated bias=False, **factory_kwargs, ) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.caduceus.backbone.embeddings.word_embeddings def set_input_embeddings(self, value): if self.config.rcps: raise NotImplementedError( "Setting input embeddings for RCPS LM is not supported." ) self.caduceus.backbone.embeddings.word_embeddings = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): """Overrides output embeddings.""" if self.config.rcps: raise NotImplementedError( "Setting output embeddings for RCPS LM is not supported." 
) self.lm_head = new_embeddings def tie_weights(self): """Tie weights, accounting for RCPS.""" if self.config.rcps: self.lm_head.set_weight(self.get_input_embeddings().weight) else: super().tie_weights() def get_decoder(self): """Get decoder (backbone) for the model.""" return self.caduceus def set_decoder(self, decoder): """Set decoder (backbone) for the model.""" self.caduceus = decoder def forward( self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, loss_weights: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MaskedLMOutput]: """HF-compatible forward method.""" output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.caduceus( input_ids=input_ids, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: if loss_weights is not None: loss = weighted_cross_entropy( logits, labels, loss_weights, ignore_index=self.config.pad_token_id ) else: loss = cross_entropy( logits, labels, ignore_index=self.config.pad_token_id ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return MaskedLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, ) class AxialCaduceusForMaskedLM(AxialCaduceusPreTrainedModel): """HF-compatible Caduceus model for masked language modeling.""" def __init__(self, config: CaduceusConfig, device=None, dtype=None, **kwargs): super().__init__(config, **kwargs) factory_kwargs = {"device": device, "dtype": dtype} self.caduceus = AxialCaduceus(config, **factory_kwargs, **kwargs) if config.rcps: self.lm_head = RCPSLMHead( complement_map=self.config.complement_map, # Use caduceus config as it might have been updated vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated true_dim=config.d_model, dtype=dtype, ) else: self.lm_head = nn.Linear( config.d_model, self.config.vocab_size, # Use caduceus config as it might have been updated bias=False, **factory_kwargs, ) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.caduceus.backbone.embeddings.word_embeddings def set_input_embeddings(self, value): if self.config.rcps: raise NotImplementedError( "Setting input embeddings for RCPS LM is not supported." ) self.caduceus.backbone.embeddings.word_embeddings = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): """Overrides output embeddings.""" if self.config.rcps: raise NotImplementedError( "Setting output embeddings for RCPS LM is not supported." 
) self.lm_head = new_embeddings def tie_weights(self): """Tie weights, accounting for RCPS.""" if self.config.rcps: self.lm_head.set_weight(self.get_input_embeddings().weight) else: super().tie_weights() def get_decoder(self): """Get decoder (backbone) for the model.""" return self.caduceus def set_decoder(self, decoder): """Set decoder (backbone) for the model.""" self.caduceus = decoder def forward( self, input_ids: torch.LongTensor = None, inputs_embeds: Optional[torch.FloatTensor] = None, input_positions: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None, loss_weights: Optional[torch.FloatTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MaskedLMOutput]: """HF-compatible forward method.""" output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.caduceus( input_ids=input_ids, inputs_embeds=inputs_embeds, input_positions=input_positions, output_hidden_states=output_hidden_states, return_dict=return_dict, ) hidden_states = outputs[0] logits = self.lm_head(hidden_states) logits = logits.float() loss = None if labels is not None: if loss_weights is not None: loss = weighted_cross_entropy( logits, labels, loss_weights, ignore_index=self.config.pad_token_id ) else: loss = cross_entropy( logits, labels, ignore_index=self.config.pad_token_id ) if not return_dict: output = (logits,) + outputs[1:] return (loss,) + output if loss is not None else output return MaskedLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, ) class MixedAxialCaduceusForMaskedLM(CaduceusPreTrainedModel): """HF-compatible Caduceus model for masked language modeling.""" def __init__(self, config: MixedCaduceusConfig, device=None, dtype=None, **kwargs): super().__init__(config, **kwargs) factory_kwargs = {"device": device, "dtype": dtype} self.caduceus = MixedAxialCaduceus(config, **factory_kwargs, **kwargs) if config.rcps: self.lm_head = RCPSLMHead( complement_map=self.config.complement_map, # Use caduceus config as it might have been updated vocab_size=self.config.vocab_size, # Use caduceus config as it might have been updated true_dim=config.d_model, dtype=dtype, ) else: self.lm_head = nn.Linear( config.d_model, self.config.vocab_size, # Use caduceus config as it might have been updated bias=False, **factory_kwargs, ) # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.caduceus.backbone.embeddings.word_embeddings def set_input_embeddings(self, value): if self.config.rcps: raise NotImplementedError( "Setting input embeddings for RCPS LM is not supported." ) self.caduceus.backbone.embeddings.word_embeddings = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): """Overrides output embeddings.""" if self.config.rcps: raise NotImplementedError( "Setting output embeddings for RCPS LM is not supported." 
            )
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Tie weights, accounting for RCPS."""
        if self.config.rcps:
            self.lm_head.set_weight(self.get_input_embeddings().weight)
        else:
            super().tie_weights()

    def get_decoder(self):
        """Get decoder (backbone) for the model."""
        return self.caduceus

    def set_decoder(self, decoder):
        """Set decoder (backbone) for the model."""
        self.caduceus = decoder

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        loss_weights: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """HF-compatible forward method."""
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.caduceus(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()
        loss = None
        if labels is not None:
            if loss_weights is not None:
                loss = weighted_cross_entropy(
                    logits, labels, loss_weights, ignore_index=self.config.pad_token_id
                )
            else:
                loss = cross_entropy(
                    logits, labels, ignore_index=self.config.pad_token_id
                )
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
        )


class CaduceusForSequenceClassification(CaduceusPreTrainedModel):
    def __init__(
        self,
        config: CaduceusConfig,
        pooling_strategy: str = "mean",
        conjoin_train: bool = False,
        conjoin_eval: bool = False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__(config, **kwargs)
        if pooling_strategy not in ["mean", "max", "first", "last"]:
            raise NotImplementedError(
                f"Pooling strategy `{pooling_strategy}` not implemented."
            )
        self.pooling_strategy = pooling_strategy
        factory_kwargs = {"device": device, "dtype": dtype}
        self.num_labels = kwargs.get("num_labels", config.num_labels)
        self.caduceus = Caduceus(config, **factory_kwargs, **kwargs)
        self.score = nn.Linear(config.d_model, self.num_labels, bias=False)
        self.conjoin_train = conjoin_train
        self.conjoin_eval = conjoin_eval

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.caduceus.backbone.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        if self.config.rcps:
            raise NotImplementedError(
                "Setting input embeddings for RCPS LM is not supported."
            )
        self.caduceus.backbone.embeddings.word_embeddings = value

    def pool_hidden_states(self, hidden_states, sequence_length_dim=1):
        """Pools hidden states along sequence length dimension."""
        if self.pooling_strategy == "mean":  # Mean pooling along sequence length dimension
            return hidden_states.mean(dim=sequence_length_dim)
        if self.pooling_strategy == "max":  # Max pooling along sequence length dimension
            return hidden_states.max(dim=sequence_length_dim).values
        if self.pooling_strategy == "last":  # Use embedding of last token in the sequence
            return hidden_states.moveaxis(sequence_length_dim, 0)[-1, ...]
        if self.pooling_strategy == "first":  # Use embedding of first token in the sequence
            return hidden_states.moveaxis(sequence_length_dim, 0)[0, ...]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in
            `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed
            (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Get hidden representations from the backbone
        if self.config.rcps:  # Hidden states have 2 * d_model channels for RCPS
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=inputs_embeds,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = torch.stack(
                [
                    transformer_outputs[0][..., : self.config.d_model],
                    torch.flip(
                        transformer_outputs[0][..., self.config.d_model :], dims=[1, 2]
                    ),
                ],
                dim=-1,
            )
        elif self.conjoin_train or (
            self.conjoin_eval and not self.training
        ):  # For conjoining / post-hoc conjoining
            assert input_ids is not None, "`input_ids` must be provided for conjoining."
            assert (
                input_ids.ndim == 3
            ), "`input_ids` must be 3D tensor: channels corresponds to forward and rc strands."
            transformer_outputs = self.caduceus(
                input_ids[..., 0],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            transformer_outputs_rc = self.caduceus(
                input_ids[..., 1],
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            # Stack along channel dimension (dim=-1)
            hidden_states = torch.stack(
                [transformer_outputs[0], transformer_outputs_rc[0]], dim=-1
            )
        else:
            transformer_outputs = self.caduceus(
                input_ids,
                inputs_embeds=None,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]

        # Pool and get logits
        pooled_hidden_states = self.pool_hidden_states(hidden_states)
        # Potentially run `score` twice (with parameters shared) for conjoining
        if (
            hidden_states.ndim == 4
        ):  # bsz, seq_len, hidden_dim, 2 where last channel has the stacked fwd and rc reps
            logits_fwd = self.score(pooled_hidden_states[..., 0])
            logits_rc = self.score(pooled_hidden_states[..., 1])
            logits = (logits_fwd + logits_rc) / 2
        else:
            logits = self.score(pooled_hidden_states)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (
                    labels.dtype == torch.long or labels.dtype == torch.int
                ):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                if self.num_labels == 1:
                    loss = F.mse_loss(logits.squeeze(), labels.squeeze())
                else:
                    loss = F.mse_loss(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss = F.cross_entropy(
                    logits.view(-1, self.num_labels), labels.view(-1)
                )
            elif self.config.problem_type == "multi_label_classification":
                loss = F.binary_cross_entropy_with_logits(logits, labels)

        if not return_dict:
            output = (logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
        )
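

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original API).
# The field names passed to CaduceusConfig below are assumptions based on the
# attributes read in this file; consult configuration_caduceus.py for the
# authoritative signature and defaults. The mamba_ssm kernels generally
# require a CUDA device, so everything is moved to "cuda".
if __name__ == "__main__":
    config = CaduceusConfig(
        d_model=128,
        n_layer=2,
        vocab_size=12,
        rcps=False,            # no reverse-complement equivariance in this sketch
        fused_add_norm=False,  # avoid the Triton fused add+norm path
    )
    model = CaduceusForMaskedLM(config).to("cuda")
    input_ids = torch.randint(0, config.vocab_size, (2, 64), device="cuda")  # (batch, seq_len)
    output = model(input_ids=input_ids)
    print(output.logits.shape)  # (2, 64, possibly padded vocab size)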