import torch

# Older PyTorch builds do not expose torch.library.wrap_triton; install a no-op
# shim so modules that decorate Triton kernels with it still import cleanly.
if not hasattr(torch.library, 'wrap_triton'):
    def wrap_triton(fn):
        return fn
    torch.library.wrap_triton = wrap_triton

# Let torch.compile / dynamo trace the .item() calls used in the MoE dispatch loop.
import torch._dynamo
torch._dynamo.config.capture_scalar_outputs = True

import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple, Union

from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutputWithPast, SequenceClassifierOutput

import bert_padding
from attention import FlexBertUnpadRopeAttention

from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

# Prefer the fused Liger layer-norm kernel when it is installed; otherwise fall
# back to the standard nn.LayerNorm.
try:
    from liger_kernel.transformers import LigerLayerNorm
    LayerNormClass = LigerLayerNorm
except ImportError:
    LayerNormClass = nn.LayerNorm


class CustomTransformerConfig(PretrainedConfig):
    """
    Configuration class for CustomTransformer model.

    This class stores the configuration of a CustomTransformer model and is compatible
    with HuggingFace's transformers library. It replaces the old ModelConfig dataclass.
    """
    model_type = "custom_transformer"

    auto_map = {
        "AutoConfig": "model.CustomTransformerConfig",
        "AutoModel": "model.CustomTransformerModel",
        "AutoModelForMaskedLM": "model.CustomTransformerForMaskedLM",
        "AutoModelForSequenceClassification": "model.CustomTransformerForSequenceClassification",
    }

    def __init__(
        self,
        vocab_size: int = 50368,
        num_dims: int = 768,
        num_heads: int = 12,
        num_kv_heads: int = 12,
        num_layers: int = 12,
        ffn_hidden_dims: int = 1536,
        layernorm_eps: float = 1e-6,
        attention_probs_dropout_prob: float = 0.1,
        attn_qkv_bias: bool = False,
        attn_out_bias: bool = False,
        attn_out_dropout_prob: float = 0.0,
        global_attn_every_n_layers: int = 3,
        sliding_window: int = 128,
        rotary_emb_base: int = 10000,
        context_len: int = 128,
        use_cache: bool = False,
        use_flash: bool = True,
        use_moe: bool = True,
        moe_num_experts: int = 15,
        moe_routed_experts: int = 1,
        moe_eps: float = 1e-6,
        moe_aux_loss_coef: float = 0.01,
        moe_shared_experts: int = 1,
        use_lossfreebalance: bool = True,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        mask_token_id: int = 3,
        rope_theta: float = 1e5,
        ffn_dim_multiplier: Optional[int] = None,
        rotary_emb_dim: Optional[int] = None,
        local_attn_rotary_emb_base: int = -1,
        local_attn_rotary_emb_dim: Optional[int] = None,
        rotary_emb_scale_base: Optional[float] = None,
        rotary_emb_interleaved: bool = False,
        use_fa2: Optional[bool] = None,
        deterministic_fa2: bool = False,
        use_sdpa_attn_mask: bool = False,
        num_labels: int = 2,
        classifier_dropout: Optional[float] = None,
        **kwargs
    ):
        """Initialize CustomTransformerConfig."""
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs
        )

        self.vocab_size = vocab_size
        self.num_dims = num_dims
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_layers = num_layers
        self.ffn_hidden_dims = ffn_hidden_dims
        self.layernorm_eps = layernorm_eps
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.attn_qkv_bias = attn_qkv_bias
        self.attn_out_bias = attn_out_bias
        self.attn_out_dropout_prob = attn_out_dropout_prob
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.sliding_window = sliding_window
        self.rotary_emb_base = rotary_emb_base
        self.context_len = context_len
        self.use_cache = use_cache
        self.use_flash = use_flash
        self.use_moe = use_moe
        self.moe_num_experts = moe_num_experts
        self.moe_routed_experts = moe_routed_experts
        self.moe_eps = moe_eps
        self.moe_aux_loss_coef = moe_aux_loss_coef
        self.moe_shared_experts = moe_shared_experts
        self.use_lossfreebalance = use_lossfreebalance
        self.mask_token_id = mask_token_id
        self.rope_theta = rope_theta
        self.ffn_dim_multiplier = ffn_dim_multiplier
        self.rotary_emb_dim = rotary_emb_dim
        self.local_attn_rotary_emb_base = local_attn_rotary_emb_base
        self.local_attn_rotary_emb_dim = local_attn_rotary_emb_dim
        self.rotary_emb_scale_base = rotary_emb_scale_base
        self.rotary_emb_interleaved = rotary_emb_interleaved
        self.use_fa2 = use_fa2
        self.deterministic_fa2 = deterministic_fa2
        self.use_sdpa_attn_mask = use_sdpa_attn_mask
        self.num_labels = num_labels
        self.classifier_dropout = classifier_dropout

        # Aliases expected by HuggingFace utilities and the FlexBERT attention module.
        self.hidden_size = num_dims
        self.num_attention_heads = num_heads
        self.embedding_size = num_dims

        # FlashAttention-2 usage follows the generic flash flag unless set explicitly.
        if self.use_fa2 is None:
            self.use_fa2 = self.use_flash


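# Illustrative usage sketch (not executed on import): the hyperparameters below
# are placeholders chosen for the example, not the defaults used in training.
#
#   config = CustomTransformerConfig(num_layers=4, num_dims=256, num_heads=8,
#                                    num_kv_heads=8, ffn_hidden_dims=512)
#   model = CustomTransformerForMaskedLM(config)
#   config.save_pretrained("my_checkpoint")  # auto_map lets the Auto* classes reload it later
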
@dataclass
class ModelConfig:
    vocab_size: int

    num_dims: int
    num_heads: int
    num_kv_heads: int
    num_layers: int
    ffn_hidden_dims: int

    context_len: int
    use_cache: bool
    use_flash: bool
    use_moe: bool

    moe_num_experts: int
    moe_routed_experts: int
    moe_eps: float = 1e-6
    moe_aux_loss_coef: float = 0.00
    moe_shared_experts: int = 0
    use_lossfreebalance: bool = False

    layernorm_eps: float = 1e-6
    rope_theta: float = 1e5

    attention_probs_dropout_prob: float = 0.0
    attn_qkv_bias: bool = False
    attn_out_bias: bool = False
    attn_out_dropout_prob: float = 0.0
    global_attn_every_n_layers: int = 0
    sliding_window: int = -1
    rotary_emb_dim: Optional[int] = None
    rotary_emb_base: Optional[float] = None
    local_attn_rotary_emb_base: int = -1
    local_attn_rotary_emb_dim: Optional[int] = None
    rotary_emb_scale_base: Optional[float] = None
    rotary_emb_interleaved: bool = False
    use_fa2: Optional[bool] = None
    deterministic_fa2: bool = False
    use_sdpa_attn_mask: bool = False
    hidden_size: Optional[int] = None
    num_attention_heads: Optional[int] = None
    embedding_size: Optional[int] = None

    ffn_dim_multiplier: Optional[int] = None

    def __post_init__(self):
        if self.hidden_size is None:
            self.hidden_size = self.num_dims
        if self.num_attention_heads is None:
            self.num_attention_heads = self.num_heads
        if self.rotary_emb_base is None:
            self.rotary_emb_base = self.rope_theta
        if self.use_fa2 is None:
            self.use_fa2 = self.use_flash


class FlexBertUnpadAttention(nn.Module):
    """Thin wrapper that preserves the state_dict key path: block.attention.attn.*

    In ModernBERT-style global unpadding the data is already (total_nnz, dim) so
    this wrapper just forwards directly to FlexBertUnpadRopeAttention without
    any pad/unpad work. cu_seqlens, max_seqlen, indices, and attn_mask are
    passed through from the Transformer level.
    """
    def __init__(self, config, layer_id: Optional[int] = None):
        super().__init__()
        self.attn = FlexBertUnpadRopeAttention(config=config, layer_id=layer_id)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: int,
        indices: torch.Tensor,
        attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Forward on already-unpadded data.

        Args:
            hidden_states: (total_nnz, dim)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, seq_len)

        Returns:
            (total_nnz, dim)
        """
        return self.attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            indices=indices,
            attn_mask=attn_mask,
        )


class FeedForward(nn.Module):
    """Default (GELU-gated) feed-forward layer. Works on both 2D (total_nnz, dim) and 3D inputs."""
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.ffn_hidden_dims

        self.w1 = nn.Linear(config.num_dims, self.hidden_dim, bias=False)
        self.w2 = nn.Linear(self.hidden_dim, config.num_dims, bias=False)
        self.w3 = nn.Linear(config.num_dims, self.hidden_dim, bias=False)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor):
        # Return (output, aux_loss) so the dense FFN matches the MoE interface; there is no aux loss here.
        return self.w2(self.act(self.w1(x)) * self.w3(x)), None


class FFNwMoE(nn.Module):
    """Mixture-of-Experts feed-forward layer with optional shared experts.

    Works on 2D (total_nnz, dim) unpadded inputs; 3D inputs are flattened and
    reshaped back. Expert weights are stored as stacked nn.Parameters of shape
    (num_experts, out_dim, in_dim) and dispatched with a sort-based routine
    (see _compute_expert_outputs) rather than per-expert nn.Linear modules.

    forward() returns:
        output: combined outputs from the routed and shared experts
        aux_loss: auxiliary load-balancing loss, or routing metadata
            (router_probs, topk_indices) when loss-free balancing is enabled
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.ffn_hidden_dims
        self.num_dims = config.num_dims

        self.moe_routed_experts = config.moe_routed_experts
        self.moe_aux_loss_coef = config.moe_aux_loss_coef
        self.moe_eps = config.moe_eps
        self.moe_shared_experts = config.moe_shared_experts
        self.num_experts = config.moe_num_experts

        self.use_lossfreebalance = config.use_lossfreebalance

        self.router = nn.Linear(config.num_dims, self.num_experts, bias=False)

        # Stacked expert weights: (num_experts, out_dim, in_dim).
        self.w1_stacked = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, config.num_dims))
        self.w2_stacked = nn.Parameter(torch.empty(self.num_experts, config.num_dims, self.hidden_dim))
        self.w3_stacked = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, config.num_dims))

        for i in range(self.num_experts):
            nn.init.kaiming_uniform_(self.w1_stacked.data[i])
            nn.init.kaiming_uniform_(self.w2_stacked.data[i])
            nn.init.kaiming_uniform_(self.w3_stacked.data[i])

        # Shared experts are applied to every token in addition to the routed experts.
        self.shared_experts = nn.ModuleList()
        for _ in range(self.moe_shared_experts):
            self.shared_experts.append(
                nn.ModuleList([
                    nn.Linear(config.num_dims, self.hidden_dim, bias=False),
                    nn.Linear(self.hidden_dim, config.num_dims, bias=False),
                    nn.Linear(config.num_dims, self.hidden_dim, bias=False)
                ]))

        # Learnable per-expert bias used only by loss-free balancing (added to the router logits).
        if self.use_lossfreebalance:
            self.expert_biases = nn.Parameter(torch.zeros(self.num_experts))

    def forward(self, x: torch.Tensor):
        input_shape = x.shape
        if x.ndim == 3:
            c_batch_size, c_context_len, c_dim = input_shape
            x_flat = x.view(-1, c_dim)
        else:
            x_flat = x
            c_dim = x.shape[-1]

        router_out = self.router(x_flat)
        router_probs = F.softmax(router_out, dim=-1)

        _, topk_indices = router_out.topk(self.moe_routed_experts, dim=-1)
        # Cached for inspection and external expert-bias updates.
        self.last_topk_indices = topk_indices.detach()

        aux_loss, topk_probs = self._compute_aux_loss(router_out, router_probs, topk_indices)

        output = self._compute_expert_outputs(x_flat, topk_indices, topk_probs, router_probs)

        if x.ndim == 3:
            output = output.view(c_batch_size, c_context_len, c_dim)

        return output, aux_loss

    def _compute_aux_loss(self, router_out, router_probs, topk_indices):
        """Compute the load-balancing auxiliary loss, or routing metadata when loss-free balancing is enabled."""
        if not self.use_lossfreebalance:
            topk_probs, _ = router_probs.topk(self.moe_routed_experts, dim=-1)
            # Fraction of tokens whose top-1 choice is each expert, against the mean router probability per expert.
            expert_mask = F.one_hot(topk_indices[:, 0], self.num_experts).float()
            density = expert_mask.mean(dim=0)
            router_prob_mean = router_probs.mean(dim=0)
            aux_loss = self.moe_aux_loss_coef * torch.sum(density * router_prob_mean) * self.num_experts

        else:
            # Loss-free balancing: bias the router logits per expert and return the
            # routing information instead of a scalar loss.
            router_out = router_out + self.expert_biases
            router_probs = torch.sigmoid(router_out)
            topk_probs = router_probs.gather(-1, topk_indices)
            topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

            aux_loss = (router_probs, topk_indices)
        return aux_loss, topk_probs

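    # Worked example for the auxiliary loss above (illustrative numbers only):
    # with 4 experts, top-1 routing and 8 tokens whose top-1 choices are
    # [3, 3, 1, 0, 3, 2, 3, 1], density = [1/8, 2/8, 1/8, 4/8]. If the mean router
    # probabilities are [0.20, 0.25, 0.15, 0.40], then
    #   aux_loss = coef * num_experts * sum(density * mean_prob)
    #            = 0.01 * 4 * 0.30625 ≈ 0.0123,
    # and the sum is smallest when tokens are spread evenly over the experts.
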
    def _compute_expert_outputs(self, x_flat, topk_indices, topk_probs, router_probs):
        """Compute expert outputs using sort-based dispatch with stacked weights.

        Sort tokens by expert, slice contiguous chunks, run each expert via
        matmul on the stacked weight tensors. No weight duplication, minimal
        memory overhead.
        """
        num_tokens, dim = x_flat.shape

        # Flatten the (token, routed-expert-slot) pairs.
        flat_expert_ids = topk_indices.view(-1)
        flat_probs = topk_probs.view(-1)
        flat_token_ids = torch.arange(num_tokens, device=x_flat.device).unsqueeze(1).expand(-1, self.moe_routed_experts).reshape(-1)

        # Sort the pairs by expert id so each expert's tokens form a contiguous chunk.
        sorted_expert_ids, sort_indices = flat_expert_ids.sort(stable=True)
        sorted_token_ids = flat_token_ids[sort_indices]
        sorted_probs = flat_probs[sort_indices]

        sorted_x = x_flat[sorted_token_ids]

        # Per-expert chunk boundaries: expert_offsets[e] .. expert_offsets[e + 1].
        expert_counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts)
        expert_offsets = torch.zeros(self.num_experts + 1, dtype=torch.long, device=x_flat.device)
        torch.cumsum(expert_counts, dim=0, out=expert_offsets[1:])

        sorted_output = torch.zeros_like(sorted_x)
        for expert_id in range(self.num_experts):
            start = expert_offsets[expert_id].item()
            end = expert_offsets[expert_id + 1].item()
            if start == end:
                continue
            expert_input = sorted_x[start:end]
            # GELU-gated FFN for this expert's chunk.
            h1 = F.linear(expert_input, self.w1_stacked[expert_id])
            h3 = F.linear(expert_input, self.w3_stacked[expert_id])
            h = F.gelu(h1) * h3
            sorted_output[start:end] = F.linear(h, self.w2_stacked[expert_id])

        # Weight each chunk by its routing probability.
        sorted_output = sorted_output * sorted_probs.unsqueeze(-1)

        # Scatter results back to token order, summing over the routed experts of each token.
        output = torch.zeros_like(x_flat)
        output.scatter_add_(0, sorted_token_ids.unsqueeze(-1).expand_as(sorted_output), sorted_output)

        # Shared experts run densely on every token.
        for shared_expert_id in range(self.moe_shared_experts):
            w1, w2, w3 = self.shared_experts[shared_expert_id]
            expert_output = w2(F.gelu(w1(x_flat)) * w3(x_flat))
            output = output + expert_output

        return output
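
    # Dispatch sketch (illustrative): with 4 tokens, 2 experts and top-1 routing,
    # assignments t0->e1, t1->e0, t2->e1, t3->e0 sort to sorted_token_ids = [1, 3, 0, 2]
    # with expert_offsets = [0, 2, 4]; expert 0 then processes rows 0:2 of sorted_x,
    # expert 1 rows 2:4, and scatter_add_ restores the original token order.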


class Block(nn.Module):
    """Transformer block operating on unpadded (total_nnz, dim) tensors.

    Receives unpadding metadata (cu_seqlens, max_seqlen, indices, attn_mask)
    from the Transformer level and passes them to attention. Norms and FFN
    operate directly on the 2D unpadded tensor, avoiding wasted compute on
    padding tokens.
    """
    def __init__(self, config, layer_id: Optional[int] = None):
        super().__init__()
        self.is_first_block = (layer_id == 0)

        self.attention = FlexBertUnpadAttention(config, layer_id=layer_id)
        if config.use_moe:
            self.ffn = FFNwMoE(config)
        else:
            self.ffn = FeedForward(config)

        self.norm_attention = LayerNormClass(config.num_dims, eps=config.layernorm_eps)
        self.norm_ffn = LayerNormClass(config.num_dims, eps=config.layernorm_eps)

    def forward(self, x, cu_seqlens, max_seqlen, indices, attn_mask):
        """
        Args:
            x: (total_nnz, dim) - unpadded hidden states
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            indices: (total_nnz,)
            attn_mask: (batch, seq_len)

        Returns:
            x: (total_nnz, dim)
            aux_loss: auxiliary loss from MoE or None
        """
        # The embedding output is already LayerNormed, so the first block skips
        # its pre-attention norm (ModernBERT-style).
        if self.is_first_block:
            attn_in = x
        else:
            attn_in = self.norm_attention(x)

        x = x + self.attention(
            attn_in,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            indices=indices,
            attn_mask=attn_mask,
        )

        ffn_out, aux_loss = self.ffn(self.norm_ffn(x))
        x = x + ffn_out
        return x, aux_loss


class Transformer(nn.Module):
    """ModernBERT-style Transformer: unpad once before embeddings, repad once at
    the end. All blocks, norms, and FFNs operate on (total_nnz, dim) tensors,
    avoiding wasted compute on padding tokens.
    """
    def __init__(self, config):
        super().__init__()

        self.vocab_size = config.vocab_size
        self.num_dims = config.num_dims
        self.num_heads = config.num_heads
        self.context_len = config.context_len
        self.use_moe = config.use_moe
        self.use_lossfreebalance = config.use_lossfreebalance and self.use_moe

        self.num_layers = config.num_layers

        self.tokens_embedding = nn.Embedding(config.vocab_size, config.num_dims)
        self.norm_embeddings = LayerNormClass(config.num_dims, eps=config.layernorm_eps)

        self.blocks = nn.ModuleList()
        for layer_id in range(self.num_layers):
            self.blocks.append(Block(config, layer_id=layer_id))

        self.norm = LayerNormClass(config.num_dims, eps=config.layernorm_eps)
        self.ll_head = nn.Linear(config.num_dims, config.vocab_size, bias=False)

        # Tie the input embedding and LM head weights.
        self.tokens_embedding.weight = self.ll_head.weight

    def _unpad(self, input_ids, attention_mask):
        """Compute unpadding metadata and unpad input_ids before embedding.

        Unpads input_ids (cheap 1D integer indexing) so that embedding and
        all subsequent layers only process real tokens.

        Args:
            input_ids: (batch, seq_len)
            attention_mask: (batch, seq_len) or None

        Returns:
            input_ids_unpadded: (total_nnz,)
            indices: (total_nnz,)
            cu_seqlens: (batch + 1,)
            max_seqlen: int
            attn_mask: (batch, seq_len)
            batch_size: int
            seq_len: int
        """
        batch_size, seq_len = input_ids.shape

        if attention_mask is None:
            attn_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.int32)
        else:
            attn_mask = attention_mask.to(dtype=torch.int32)

        # The ids are unpadded as a width-1 float "hidden states" tensor so they go through
        # the same path as activations, then cast back to long. Token ids are far below
        # 2**24, so the float32 round trip is exact.
        input_ids_3d = input_ids.unsqueeze(-1).float()
        input_ids_unpadded, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(input_ids_3d, attn_mask)
        input_ids_unpadded = input_ids_unpadded.squeeze(-1).long()

        return input_ids_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len
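
    # Worked example (illustrative): for attn_mask = [[1, 1, 1, 0], [1, 1, 0, 0]]
    # the unpadded batch has total_nnz = 5 real tokens, indices = [0, 1, 2, 4, 5]
    # (positions into the flattened (batch * seq_len) view), cu_seqlens = [0, 3, 5]
    # and max_seqlen = 3; pad_input() later inverts this using the same indices.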
    def forward(
        self,
        x: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        start_pos: int = 0,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        # Unpad once; every block then works on (total_nnz, dim).
        x_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len = self._unpad(x, attention_mask)

        x = self.tokens_embedding(x_unpadded)
        x = self.norm_embeddings(x)

        total_aux_loss = 0

        for block in self.blocks:
            x, aux_loss = block(
                x,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                indices=indices,
                attn_mask=attn_mask,
            )
            if self.use_moe and not self.use_lossfreebalance:
                total_aux_loss += aux_loss

        x = self.norm(x)

        # Repad once before the LM head so logits come back as (batch, seq_len, vocab).
        x = bert_padding.pad_input(x, indices, batch_size, seq_len)

        logits = self.ll_head(x)

        if targets is None:
            loss = None
            ce_loss = None
        else:
            c_batch_size, c_context_len, c_dim = logits.shape
            logits = logits.view(c_batch_size * c_context_len, c_dim)
            targets = targets.view(c_batch_size * c_context_len)
            ce_loss = F.cross_entropy(logits, targets)

            if self.use_moe and not self.use_lossfreebalance:
                loss = ce_loss + total_aux_loss
            else:
                # With loss-free balancing (or a dense FFN) there is no scalar aux loss to add;
                # the third return value instead carries the last block's aux_loss payload,
                # i.e. (router_probs, topk_indices) for loss-free balancing or None otherwise.
                loss = ce_loss
                ce_loss = aux_loss

        return logits, loss, ce_loss
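
    # Note (assumption about the intended training loop, not implemented here): with
    # use_lossfreebalance=True the returned routing payload (router_probs, topk_indices)
    # is what a trainer would need in order to nudge each FFNwMoE.expert_biases entry
    # between steps (raising biases of under-loaded experts and lowering over-loaded
    # ones), since nothing in this module updates those biases.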
    @torch.no_grad()
    def generate(self, x: torch.Tensor, max_tokens: int, temperature: float = 1.0, top_k: int = 50,
                 use_cache: bool = False):
        """Sample up to max_tokens continuation tokens for x using temperature and top-k filtering."""
        for c_tkn_pos in range(max_tokens):
            if use_cache:
                if c_tkn_pos == 0:
                    logits, _, ce_loss = self.forward(x, start_pos=c_tkn_pos)
                else:
                    logits, _, ce_loss = self.forward(x[:, -1:], start_pos=c_tkn_pos)
            else:
                logits, _, ce_loss = self.forward(x)

            # Keep only the last position, apply temperature, then mask everything below the top-k cutoff.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                tkl, idx = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < tkl[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)
        return x


class CustomTransformerPreTrainedModel(PreTrainedModel):
    """Base class for CustomTransformer models."""
    config_class = CustomTransformerConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = False
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        """Initialize weights - handled by the model itself."""
        pass


class CustomTransformerModel(CustomTransformerPreTrainedModel):
    """The bare CustomTransformer Model outputting raw hidden-states."""

    def __init__(self, config: CustomTransformerConfig):
        super().__init__(config)
        self.config = config

        self.transformer = Transformer(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.tokens_embedding

    def set_input_embeddings(self, value):
        self.transformer.tokens_embedding = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Forward pass returning raw hidden states.

        position_ids, past_key_values, inputs_embeds, and use_cache are accepted
        for HuggingFace API compatibility but are not used by this model.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Unpad once, run the blocks on (total_nnz, dim), repad at the end.
        x_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len = self.transformer._unpad(input_ids, attention_mask)

        x = self.transformer.tokens_embedding(x_unpadded)
        x = self.transformer.norm_embeddings(x)

        for block in self.transformer.blocks:
            x, _ = block(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, indices=indices, attn_mask=attn_mask)

        x = self.transformer.norm(x)

        hidden_states = bert_padding.pad_input(x, indices, batch_size, seq_len)

        if not return_dict:
            return (hidden_states,)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


class CustomTransformerForMaskedLM(CustomTransformerPreTrainedModel):
    """CustomTransformer Model with a masked language modeling head on top."""
    _tied_weights_keys = ["transformer.ll_head.weight", "transformer.tokens_embedding.weight"]

    def __init__(self, config: CustomTransformerConfig):
        super().__init__(config)
        self.config = config

        self.transformer = Transformer(config)

        self.post_init()

    def get_input_embeddings(self):
        return self.transformer.tokens_embedding

    def set_input_embeddings(self, value):
        self.transformer.tokens_embedding = value

    def get_output_embeddings(self):
        return self.transformer.ll_head

    def set_output_embeddings(self, new_embeddings):
        self.transformer.ll_head = new_embeddings

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MaskedLMOutput]:
        """Forward pass for masked language modeling."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        logits, model_loss, ce_loss = self.transformer(
            input_ids, targets=labels, start_pos=0, attention_mask=attention_mask
        )

        masked_lm_loss = None
        if labels is not None:
            masked_lm_loss = model_loss

        if not return_dict:
            output = (logits,)
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )


class CustomTransformerForSequenceClassification(CustomTransformerPreTrainedModel):
    """CustomTransformer Model with a sequence classification head on top."""

    def __init__(self, config: CustomTransformerConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.transformer = Transformer(config)

        classifier_dropout = (
            config.classifier_dropout
            if config.classifier_dropout is not None
            else config.attention_probs_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.num_dims, config.num_labels)

        self._init_classifier_weights()
        self.post_init()

    def _init_classifier_weights(self):
        std = 0.02
        if isinstance(self.classifier, nn.Linear):
            self.classifier.weight.data.normal_(mean=0.0, std=std)
            if self.classifier.bias is not None:
                self.classifier.bias.data.zero_()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutput]:
        """Forward pass for sequence classification."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

        # Unpad once and run the shared backbone on (total_nnz, dim).
        x_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len = self.transformer._unpad(input_ids, attention_mask)

        x = self.transformer.tokens_embedding(x_unpadded)
        x = self.transformer.norm_embeddings(x)

        # Hidden states are repadded only when they have to be reported back to the caller.
        all_hidden_states = () if output_hidden_states else None

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (bert_padding.pad_input(x, indices, batch_size, seq_len),)

        for block in self.transformer.blocks:
            x, _ = block(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, indices=indices, attn_mask=attn_mask)

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (bert_padding.pad_input(x, indices, batch_size, seq_len),)

        x = self.transformer.norm(x)

        hidden_states = bert_padding.pad_input(x, indices, batch_size, seq_len)

        # Pool the first ([CLS]) token for classification.
        pooled_output = hidden_states[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + (all_hidden_states,) + (None,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=None,
        )


try:
    from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, AutoModelForSequenceClassification

    # Register the architecture with the Auto* factories so that from_config /
    # from_pretrained can resolve "custom_transformer" checkpoints.
    AutoConfig.register("custom_transformer", CustomTransformerConfig)
    AutoModel.register(CustomTransformerConfig, CustomTransformerModel)
    AutoModelForMaskedLM.register(CustomTransformerConfig, CustomTransformerForMaskedLM)
    AutoModelForSequenceClassification.register(CustomTransformerConfig, CustomTransformerForSequenceClassification)
except Exception:
    # Registration is best-effort; it fails harmlessly if the classes are already registered.
    pass


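# Illustrative loading sketch (assumes a checkpoint saved with this module shipped
# alongside it; "path/to/checkpoint" is a placeholder):
#
#   from transformers import AutoModelForMaskedLM
#   model = AutoModelForMaskedLM.from_pretrained("path/to/checkpoint", trust_remote_code=True)
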
def main():
    pass
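    # Minimal smoke test (illustrative; assumes the local `attention` and `bert_padding`
    # modules and their kernels are importable in this environment). The tiny sizes
    # below are placeholders, not the training configuration.
    config = CustomTransformerConfig(
        vocab_size=1000, num_dims=64, num_heads=4, num_kv_heads=4,
        num_layers=2, ffn_hidden_dims=128, use_moe=False,
    )
    model = CustomTransformerForMaskedLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    print("logits:", tuple(out.logits.shape))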


if __name__ == "__main__":
    main()