import math from typing import List, Optional, Tuple, Union import torch import torch.nn as nn from torch.nn import functional as F from transformers.activations import ACT2FN from transformers.modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) from transformers.utils import logging from einops import rearrange, repeat from torch.amp import autocast from torch import einsum from transformers.generation import GenerationMixin from transformers.modeling_utils import PreTrainedModel from .configuration_rotary_indictrans import RotaryIndicTransConfig logger = logging.get_logger(__name__) device = "cuda" if torch.cuda.is_available() else "cpu" try: from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import ( index_first_axis, pad_input, unpad_input, ) except ImportError: logger.warning( "It is highly recommended to use `flash_attention_2` for better performance with RotaryIndicTrans." "Falling back to the default `eager` implementation." 
) # Copied from transformers.models.llama.modeling_llama._get_unpad_data def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) return ( indices, cu_seqlens, max_seqlen_in_batch, ) # Copied from transformers.models.bart.modeling_bart.shift_tokens_right def shift_tokens_right( input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int ): shifted_input_ids = input_ids.new_zeros(input_ids.shape) shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() shifted_input_ids[:, 0] = decoder_start_token_id if pad_token_id is None: raise ValueError("self.model.config.pad_token_id has to be defined.") shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) return shifted_input_ids def create_position_ids_from_input_ids( input_ids, padding_idx, past_key_values_length=0 ): mask = input_ids.ne(padding_idx).int() incremental_indices = ( torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length ) * mask return incremental_indices.long() + padding_idx def rotate_half(x): x = rearrange(x, "... (d r) -> ... d r", r=2) x1, x2 = x.unbind(dim=-1) x = torch.stack((-x2, x1), dim=-1) return rearrange(x, "... d r -> ... 
def rotate_half(x):
    """Rotate adjacent feature pairs: ``(a, b, c, d, ...) -> (-b, a, -d, c, ...)``.

    The last dimension of ``x`` must have even size.
    """
    pairs = x.unflatten(-1, (-1, 2))
    x1, x2 = pairs.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


@autocast("cuda", enabled=False)
def apply_rotary_emb(cos, sin, t):
    """Apply rotary position embedding to the leading features of ``t``.

    Args:
        cos: cosine table whose last dimension is the number of rotated features.
        sin: sine table with the same shape as ``cos``.
        t: tensor to rotate; ``cos``/``sin`` must broadcast against
            ``t[..., :rot_dim]``.

    Returns:
        A tensor with the shape and dtype of ``t`` whose first ``rot_dim``
        features are rotated and whose remaining features are passed through.

    Raises:
        ValueError: if ``cos`` and ``sin`` differ in shape or the rotary
            dimension exceeds the feature dimension of ``t``.
    """
    rot_dim = cos.shape[-1]
    if rot_dim > t.shape[-1] or cos.shape != sin.shape:
        raise ValueError(
            f"Incompatible rotary shapes: cos {tuple(cos.shape)}, "
            f"sin {tuple(sin.shape)}, input {tuple(t.shape)}"
        )
    t_left, t_right = t[..., :rot_dim], t[..., rot_dim:]
    t_transformed = (t_left * cos) + (rotate_half(t_left) * sin)
    return torch.cat((t_transformed, t_right), dim=-1).type(t.dtype)


class RotaryEmbedding(torch.nn.Module):
    """Rotary position embedding with a cached cos/sin table.

    Args:
        dim: number of features to rotate (must be even).
        theta: base of the geometric frequency progression.
        interpolate_factor: positions are divided by this factor.
        cache_max_seq_len: number of positions precomputed at construction.

    The frequency vector and the cos/sin tables are non-persistent buffers, so
    they follow the module across ``.to(...)`` calls and never appear in the
    state dict.
    """

    def __init__(
        self, dim, theta=10000, interpolate_factor=1.0, cache_max_seq_len=8192
    ):
        super().__init__()
        self.cache_max_seq_len = cache_max_seq_len
        self.interpolate_factor = interpolate_factor

        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        # A buffer (not a Parameter moved with `.to`) keeps registration and
        # state-dict contents identical on CPU and GPU machines.
        self.register_buffer("freqs", inv_freq, persistent=False)

        # Kept as an attribute for callers that reach the helper via the module.
        self.apply_rotary_emb = apply_rotary_emb
        self.precompute_freqs(cache_max_seq_len)

    def precompute_freqs(self, max_seq_len):
        """(Re)build the cached cos/sin tables for positions ``[0, max_seq_len)``."""
        thetas = self.forward(max_seq_len, device=self.freqs.device)
        self.register_buffer("cached_cos", thetas.cos(), persistent=False)
        self.register_buffer("cached_sin", thetas.sin(), persistent=False)

    def rotate_queries_or_keys(self, t, seq_dim=-2, offset=0):
        """Rotate ``t`` using positions ``offset .. offset + seq_len - 1``.

        Args:
            t: tensor whose sequence axis is ``seq_dim``. The cos/sin tables are
                broadcast as ``(seq_len, rot_dim)``, so the sequence axis is
                expected to be the second-to-last one.
            seq_dim: axis used to read the sequence length.
            offset: position of the first element of ``t``.
        """
        seq_len = t.shape[seq_dim]
        end = offset + seq_len

        # The cache must cover the last requested position, not just seq_len;
        # otherwise the slice below would come back short (or empty).
        if end > self.cache_max_seq_len:
            self.cache_max_seq_len = end * 2
            self.precompute_freqs(self.cache_max_seq_len)

        cos = self.cached_cos[offset:end].to(t.device)
        sin = self.cached_sin[offset:end].to(t.device)
        return apply_rotary_emb(cos, sin, t)

    @autocast("cuda", enabled=False)
    def forward(self, seq_len, device):
        """Return the rotation angles for ``seq_len`` positions.

        The result has shape ``(seq_len, 2 * len(freqs))``; every frequency is
        repeated twice so consecutive feature pairs share one angle.
        """
        seq = torch.arange(seq_len, device=device) / self.interpolate_factor
        # Angles are always computed in float32, even if the module was cast.
        freqs = self.freqs.to(device=device, dtype=torch.float32)
        thetas = torch.outer(seq.float(), freqs)
        return thetas.repeat_interleave(2, dim=-1)
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->RotaryIndicTrans
class RotaryIndicTransAttention(nn.Module):
    """Multi-head attention with (partial) rotary position embeddings.

    Self-attention instances rotate queries and keys with a `RotaryEmbedding`
    covering half of each head's features; cross-attention instances
    (`is_cross_attention=True`) use no positional rotation.

    Args:
        embed_dim: model width; must be divisible by `num_heads`.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention probabilities.
        is_decoder: when True, `forward` returns the key/value states as cache.
        bias: whether the q/k/v/out projections carry a bias.
        is_causal: stored for subclasses that pass it to their kernels.
        is_cross_attention: disables the rotary embedding.
        config: model config; `config.rope_args` may provide `theta` and
            `interpolate_factor`.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        is_decoder: bool = False,
        bias: bool = True,
        is_causal: bool = False,
        is_cross_attention: bool = False,
        config: Optional[RotaryIndicTransConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads

        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.is_decoder = is_decoder
        self.is_causal = is_causal

        # `config` is optional in the signature, so fall back to defaults
        # instead of failing on `None.rope_args`.
        rope_args = getattr(config, "rope_args", None) or {}

        # partial rotation in RoPE
        self.rotary_pos_embed = (
            RotaryEmbedding(
                dim=self.head_dim // 2,
                theta=rope_args.get("theta", 10000),
                interpolate_factor=rope_args.get("interpolate_factor", 1.0),
            )
            if not is_cross_attention
            else None
        )

        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, embed_dim)` to `(bsz, num_heads, seq_len, head_dim)`."""
        return (
            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
            .contiguous()
        )

    def _apply_rotary_pos_emb(self, q, k, is_inference=False):
        """Rotate flattened `(bsz * num_heads, seq, head_dim)` queries and keys.

        Keys always start at position 0. With `is_inference=True` the queries
        are aligned with the *last* `q_len` key positions, which covers both
        cached single-token decoding (offset `k_len - 1`) and a full-length
        decoder pass where `q_len == k_len` (offset 0).
        """
        q = rearrange(q, "(b h) t d -> b h t d", h=self.num_heads)
        k = rearrange(k, "(b h) t d -> b h t d", h=self.num_heads)

        offset = (k.shape[-2] - q.shape[-2]) if is_inference else 0

        q = self.rotary_pos_embed.rotate_queries_or_keys(q, offset=offset)
        k = self.rotary_pos_embed.rotate_queries_or_keys(k)

        q = rearrange(q, "b h t d -> (b h) t d")
        k = rearrange(k, "b h t d -> (b h) t d")
        return q, k

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: query input of shape `(bsz, tgt_len, embed_dim)`.
            key_value_states: if given, keys/values are projected from it
                (cross-attention) instead of from `hidden_states`.
            past_key_value: cached `(key, value)` of shape
                `(bsz, num_heads, seq, head_dim)`.
            attention_mask: additive mask of shape `(bsz, 1, tgt_len, src_len)`.
            layer_head_mask: per-head multiplier of shape `(num_heads,)`.
            output_attentions: whether to return the attention probabilities.

        Returns:
            `(attn_output, attn_weights or None, past_key_value)`; the last
            element is only populated when `is_decoder` is True (or a cache was
            passed in).

        Raises:
            ValueError: on any shape mismatch of weights, masks or output.
        """
        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states) * self.scaling

        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse cached cross-attention k, v
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # extend cached self-attention k, v with the new positions
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # Keys are cached *before* rotation and re-rotated from position 0
            # on every call.
            past_key_value = (key_states, value_states)

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        src_len = key_states.size(1)

        if self.rotary_pos_embed is not None:
            query_states, key_states = self._apply_rotary_pos_emb(
                query_states, key_states, is_inference=past_key_value is not None
            )

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

        if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
                )
            attn_weights = (
                attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
                + attention_mask
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_weights = F.softmax(attn_weights, dim=-1)

        if layer_head_mask is not None:
            if layer_head_mask.size() != (self.num_heads,):
                raise ValueError(
                    f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
                    f" {layer_head_mask.size()}"
                )
            attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if output_attentions:
            # The returned view and the tensor used below share the same graph
            # node, so the returned attentions keep their gradient.
            attn_weights_reshaped = attn_weights.view(
                bsz, self.num_heads, tgt_len, src_len
            )
            attn_weights = attn_weights_reshaped.view(
                bsz * self.num_heads, tgt_len, src_len
            )
        else:
            attn_weights_reshaped = None

        attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, value_states)

        if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = rearrange(
            attn_output, "(b h) t d -> b t (h d)", h=self.num_heads, d=self.head_dim
        )
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights_reshaped, past_key_value
class RotaryIndicTransFlashAttention2(RotaryIndicTransAttention):
    """FlashAttention-2 variant of `RotaryIndicTransAttention`.

    Projections and cache layout match the parent class (the cache is stored as
    `(bsz, num_heads, seq, head_dim)`), but inside `forward` tensors use the
    `(bsz, seq, num_heads, head_dim)` layout expected by the flash-attn kernels.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape `(bsz, seq_len, embed_dim)` to `(bsz, seq_len, num_heads, head_dim)`."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)

    def _rotate_query_key(self, query_states, key_states):
        """Apply rotary embeddings to `(bsz, seq, num_heads, head_dim)` tensors.

        The rotary module expects the sequence axis second-to-last, so both
        tensors are transposed to heads-first for the rotation and transposed
        back. Queries are aligned with the last `q_len` key positions.
        """
        q = query_states.transpose(1, 2)
        k = key_states.transpose(1, 2)
        offset = k.shape[-2] - q.shape[-2]
        q = self.rotary_pos_embed.rotate_queries_or_keys(q, offset=offset)
        k = self.rotary_pos_embed.rotate_queries_or_keys(k)
        return q.transpose(1, 2).contiguous(), k.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run attention through the flash-attn kernels.

        Args:
            hidden_states: query input of shape `(bsz, q_len, embed_dim)`.
            key_value_states: source of keys/values for cross-attention.
            past_key_value: cached `(key, value)` in heads-first layout.
            attention_mask: 2-D padding mask, or `None` when nothing is padded.
            layer_head_mask: accepted for interface parity; not used here.
            output_attentions: must be False.

        Returns:
            `(attn_output, None, past_key_value)`.

        Raises:
            ValueError: if `output_attentions` is requested.
        """
        # RotaryIndicTransFlashAttention2 attention does not support output_attentions
        if output_attentions:
            raise ValueError(
                "RotaryIndicTransFlashAttention2 attention does not support output_attentions"
            )

        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = key_value_states is not None

        bsz, q_len, _ = hidden_states.size()
        query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)

        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse cached cross-attention k, v
            key_states = past_key_value[0].transpose(1, 2)
            value_states = past_key_value[1].transpose(1, 2)
        elif is_cross_attention:
            key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # extend cached self-attention k, v along the sequence axis
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat(
                [past_key_value[0].transpose(1, 2), key_states], dim=1
            )
            value_states = torch.cat(
                [past_key_value[1].transpose(1, 2), value_states], dim=1
            )
        else:
            key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # The cache keeps the heads-first layout shared with the other
            # attention implementations.
            past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))

        # Flash attention kernels do not run in float32; cast back to the
        # compute dtype when the inputs were silently upcast.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        if self.rotary_pos_embed is not None:
            query_states, key_states = self._rotate_query_key(
                query_states, key_states
            )

        attn_output = self._flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            # Attention dropout must be disabled outside of training.
            dropout=self.dropout if self.training else 0.0,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward
    def _flash_attention_forward(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        query_length,
        dropout=0.0,
        softmax_scale=None,
    ):
        """
        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
        first unpad the input, then computes the attention scores and pad the final attention scores.

        Args:
            query_states (`torch.Tensor`):
                Input query states to be passed to Flash Attention API
            key_states (`torch.Tensor`):
                Input key states to be passed to Flash Attention API
            value_states (`torch.Tensor`):
                Input value states to be passed to Flash Attention API
            attention_mask (`torch.Tensor`):
                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
                position of padding tokens and 1 for the position of non-padding tokens.
            query_length (`int`):
                Length of the query sequence, used to re-pad the output.
            dropout (`float`):
                Attention dropout
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
        """
        # Contains at least one padding token in the sequence
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            (
                query_states,
                key_states,
                value_states,
                indices_q,
                cu_seq_lens,
                max_seq_lens,
            ) = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=self.is_causal,
            )

            attn_output = pad_input(
                attn_output_unpad, indices_q, batch_size, query_length
            )
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout,
                softmax_scale=softmax_scale,
                causal=self.is_causal,
            )

        return attn_output

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input
    def _upad_input(
        self, query_layer, key_layer, value_layer, attention_mask, query_length
    ):
        """Strip padded positions from q/k/v and return the varlen metadata.

        Returns:
            `(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_q, max_seqlen_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
            indices_k,
        )

        if query_length == kv_seq_len:
            # Queries share the key padding pattern.
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim),
                indices_k,
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            # Single-token decoding: one query per batch row, nothing to unpad.
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The last `query_length` mask columns are taken as the query mask.
            attention_mask = attention_mask[:, -query_length:]
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
class RotaryIndicTransSdpaAttention(RotaryIndicTransAttention):
    """`RotaryIndicTransAttention` backed by `F.scaled_dot_product_attention`.

    Falls back to the parent (eager) implementation when attention weights or a
    per-head mask are requested, since SDPA supports neither.
    """

    def _rotate_query_key(self, query_states, key_states):
        """Apply rotary embeddings to `(bsz, num_heads, seq, head_dim)` tensors.

        Keys start at position 0; queries are aligned with the last `q_len`
        key positions so cached decoding and full-sequence passes both get
        correct relative positions.
        """
        offset = key_states.shape[-2] - query_states.shape[-2]
        query_states = self.rotary_pos_embed.rotate_queries_or_keys(
            query_states, offset=offset
        )
        key_states = self.rotary_pos_embed.rotate_queries_or_keys(key_states)
        return query_states, key_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Returns:
            `(attn_output, None, past_key_value)` on the SDPA path, or the
            parent's result when falling back to the manual implementation.

        Raises:
            ValueError: if the attention output has an unexpected shape.
        """
        if output_attentions or layer_head_mask is not None:
            logger.warning_once(
                "RotaryIndicTransModel is using RotaryIndicTransSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention"
                ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                key_value_states=key_value_states,
                past_key_value=past_key_value,
                attention_mask=attention_mask,
                layer_head_mask=layer_head_mask,
                output_attentions=output_attentions,
            )

        # if key_value_states are provided this layer is used as a cross-attention layer
        is_cross_attention = key_value_states is not None

        bsz, tgt_len, _ = hidden_states.size()

        # No manual scaling here: SDPA applies 1 / sqrt(head_dim) itself.
        query_states = self.q_proj(hidden_states)

        if (
            is_cross_attention
            and past_key_value is not None
            and past_key_value[0].shape[2] == key_value_states.shape[1]
        ):
            # reuse cached cross-attention k, v
            key_states = past_key_value[0]
            value_states = past_key_value[1]
        elif is_cross_attention:
            key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
            value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
        elif past_key_value is not None:
            # extend cached self-attention k, v with the new positions
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        else:
            key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
            value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            # Keys are cached before rotation.
            past_key_value = (key_states, value_states)

        query_states = self._shape(query_states, tgt_len, bsz)

        if self.rotary_pos_embed is not None:
            # Tensors are already (bsz, heads, seq, dim) here, so the rotary
            # module is called directly rather than through the parent's
            # helper, which expects flattened (bsz * heads, seq, dim) inputs.
            query_states, key_states = self._rotate_query_key(
                query_states, key_states
            )

        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
            # An explicit mask already encodes causality; a single query needs none.
            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
        )

        if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = rearrange(
            attn_output, "b h t d -> b t (h d)", h=self.num_heads, d=self.head_dim
        )
        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value
ROTARY_INDICTRANS_ATTENTION_CLASSES = {
    "eager": RotaryIndicTransAttention,
    "sdpa": RotaryIndicTransSdpaAttention,
    "flash_attention_2": RotaryIndicTransFlashAttention2,
}


# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->RotaryIndicTrans
class RotaryIndicTransEncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a two-layer feed-forward net.

    Layer norm is applied before each sub-layer when
    `config.encoder_normalize_before` is set, otherwise after the residual add.
    """

    def __init__(self, config: RotaryIndicTransConfig):
        super().__init__()
        width = config.encoder_embed_dim
        attention_cls = ROTARY_INDICTRANS_ATTENTION_CLASSES[config._attn_implementation]

        self.embed_dim = width
        self.self_attn = attention_cls(
            embed_dim=width,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
            config=config,
        )
        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(width, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)
        self.normalize_before = config.encoder_normalize_before

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        layer_head_mask: torch.Tensor,
        output_attentions: bool = False,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): mask forwarded unchanged to the attention module
            layer_head_mask (`torch.FloatTensor`): per-head mask of size `(encoder_attention_heads,)`
            output_attentions (`bool`, *optional*): also return the attention weights

        Returns:
            A tuple `(hidden_states,)`, extended with the attention weights when
            `output_attentions` is True.
        """
        pre_norm = self.normalize_before

        # --- self-attention sub-layer ---
        shortcut = hidden_states
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        attn_result = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states, attn_weights = attn_result[0], attn_result[1]
        hidden_states = shortcut + F.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- feed-forward sub-layer ---
        shortcut = hidden_states
        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = F.dropout(
            hidden_states, p=self.activation_dropout, training=self.training
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = shortcut + F.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        # Half precision can overflow; clamp to just inside the finite range.
        if hidden_states.dtype == torch.float16:
            if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
                limit = torch.finfo(hidden_states.dtype).max - 1000
                hidden_states = torch.clamp(hidden_states, min=-limit, max=limit)

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->RotaryIndicTrans
class RotaryIndicTransDecoderLayer(nn.Module):
    """One decoder block: causal self-attention, cross-attention, feed-forward.

    Layer norm placement (before or after each sub-layer) follows
    `config.decoder_normalize_before`.
    """

    def __init__(self, config: RotaryIndicTransConfig):
        super().__init__()
        width = config.decoder_embed_dim
        attention_cls = ROTARY_INDICTRANS_ATTENTION_CLASSES[config._attn_implementation]

        self.embed_dim = width
        self.self_attn = attention_cls(
            embed_dim=width,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=True,
            is_causal=True,
            config=config,
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout

        self.self_attn_layer_norm = nn.LayerNorm(width)
        self.encoder_attn = attention_cls(
            width,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            is_cross_attention=True,
            is_decoder=True,
            config=config,
        )
        self.encoder_attn_layer_norm = nn.LayerNorm(width)
        self.fc1 = nn.Linear(width, config.decoder_ffn_dim)
        self.fc2 = nn.Linear(config.decoder_ffn_dim, width)
        self.final_layer_norm = nn.LayerNorm(width)
        self.normalize_before = config.decoder_normalize_before

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        layer_head_mask: Optional[torch.Tensor] = None,
        cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = True,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): mask forwarded to the self-attention module
            encoder_hidden_states (`torch.FloatTensor`):
                encoder output of shape `(batch, seq_len, embed_dim)`; cross-attention is skipped when `None`
            encoder_attention_mask (`torch.FloatTensor`): mask forwarded to the cross-attention module
            layer_head_mask (`torch.FloatTensor`): per-head mask for self-attention
            cross_attn_layer_head_mask (`torch.FloatTensor`): per-head mask for cross-attention
            past_key_value (`Tuple(torch.FloatTensor)`):
                cached states; the first two entries belong to self-attention, the last two to cross-attention
            output_attentions (`bool`, *optional*): also return both attention weight tensors
            use_cache (`bool`, *optional*): also return the updated key/value cache

        Returns:
            A tuple starting with the new hidden states, followed by
            `(self_attn_weights, cross_attn_weights)` if `output_attentions`
            and by the present key/value tuple if `use_cache`.
        """
        pre_norm = self.normalize_before
        has_cache = past_key_value is not None

        # --- self-attention sub-layer ---
        shortcut = hidden_states
        if pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=past_key_value[:2] if has_cache else None,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        hidden_states = shortcut + F.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        if not pre_norm:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # --- cross-attention sub-layer (only with encoder output) ---
        cross_attn_weights = None
        cross_attn_present_key_value = None
        if encoder_hidden_states is not None:
            shortcut = hidden_states
            if pre_norm:
                hidden_states = self.encoder_attn_layer_norm(hidden_states)

            cross_result = self.encoder_attn(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                layer_head_mask=cross_attn_layer_head_mask,
                past_key_value=past_key_value[-2:] if has_cache else None,
                output_attentions=output_attentions,
            )
            hidden_states, cross_attn_weights, cross_attn_present_key_value = cross_result

            hidden_states = shortcut + F.dropout(
                hidden_states, p=self.dropout, training=self.training
            )
            if not pre_norm:
                hidden_states = self.encoder_attn_layer_norm(hidden_states)

            # Cache layout: (self_k, self_v, cross_k, cross_v).
            present_key_value = present_key_value + cross_attn_present_key_value

        # --- feed-forward sub-layer ---
        shortcut = hidden_states
        if pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states = F.dropout(
            hidden_states, p=self.activation_dropout, training=self.training
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = shortcut + F.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        if not pre_norm:
            hidden_states = self.final_layer_norm(hidden_states)

        results = (hidden_states,)
        if output_attentions:
            results = results + (self_attn_weights, cross_attn_weights)
        if use_cache:
            results = results + (present_key_value,)
        return results
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100PretrainedModel->RotaryIndicTrans
class RotaryIndicTransPreTrainedModel(PreTrainedModel):
    """Shared base class: config binding, weight init and checkpointing toggle."""

    config_class = RotaryIndicTransConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["RotaryIndicTransAttention"]

    def _init_weights(self, module):
        """Initialise `nn.Linear` / `nn.Embedding` weights from N(0, init_std).

        Linear biases and the embedding row at `padding_idx` are zeroed; any
        other module type is left untouched.
        """
        std = self.config.init_std

        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        module.weight.data.normal_(mean=0.0, std=std)

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set `gradient_checkpointing` on encoder/decoder stacks only."""
        checkpointable = (RotaryIndicTransDecoder, RotaryIndicTransEncoder)
        if isinstance(module, checkpointable):
            module.gradient_checkpointing = value
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100EncoderLayer->RotaryIndicTrans
class RotaryIndicTransEncoder(RotaryIndicTransPreTrainedModel):
    """Stack of `config.encoder_layers` `RotaryIndicTransEncoderLayer` blocks.

    Args:
        config: model configuration.
        embed_tokens: optional embedding whose weight is shared with this
            encoder's token embedding.
    """

    def __init__(
        self,
        config: RotaryIndicTransConfig,
        embed_tokens: Optional[nn.Embedding] = None,
    ):
        super().__init__(config)

        width = config.encoder_embed_dim
        attn_impl = config._attn_implementation

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.embed_scale = math.sqrt(width) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.encoder_vocab_size, width, self.padding_idx
        )
        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.layers = nn.ModuleList(
            [RotaryIndicTransEncoderLayer(config) for _ in range(config.encoder_layers)]
        )

        self.layer_norm = None
        if config.encoder_normalize_before:
            self.layer_norm = nn.LayerNorm(width)

        self.layernorm_embedding = None
        if config.layernorm_embedding:
            self.layernorm_embedding = nn.LayerNorm(width)

        self._use_flash_attention_2 = attn_impl == "flash_attention_2"
        self._use_sdpa = attn_impl == "sdpa"

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Token indices. Mutually exclusive with `inputs_embeds`.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Padding mask: 1 for tokens that are **not masked**, 0 for tokens that are **masked**.
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Per-layer, per-head mask: 1 keeps a head, 0 nullifies it.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Pre-computed embeddings used instead of looking up `input_ids`.
            output_attentions (`bool`, *optional*):
                Whether to return the attention tensors of all layers.
            output_hidden_states (`bool`, *optional*):
                Whether to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether to return a `BaseModelOutput` instead of a plain tuple.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if `head_mask` does not
                have one row per layer.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        else:
            input_shape = inputs_embeds.size()[:-1]

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        hidden_states = inputs_embeds
        if self.layernorm_embedding is not None:
            hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

        # Convert the 2-D padding mask into whatever the attention backend expects.
        if attention_mask is not None:
            if self._use_flash_attention_2:
                # flash-attn takes the raw 2-D mask, and only if something is padded
                attention_mask = attention_mask if 0 in attention_mask else None
            elif self._use_sdpa and head_mask is None and not output_attentions:
                attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    attention_mask, inputs_embeds.dtype
                )
            else:
                attention_mask = _prepare_4d_attention_mask(
                    attention_mask, inputs_embeds.dtype
                )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        num_layers = len(self.layers)
        if head_mask is not None and head_mask.size()[0] != num_layers:
            raise ValueError(
                f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                f" {head_mask.size()[0]}."
            )

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            # LayerDrop: randomly skip whole layers during training.
            dropout_probability = torch.rand([])
            skip_the_layer = bool(
                self.training and (dropout_probability < self.layerdrop)
            )
            layer_head_mask = head_mask[idx] if head_mask is not None else None

            # under deepspeed zero3 every layer is still executed
            if deepspeed_zero3_is_enabled or not skip_the_layer:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask,
                        layer_head_mask,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=layer_head_mask,
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if self.layer_norm is not None:
            hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if return_dict:
            return BaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=encoder_states,
                attentions=all_attentions,
            )

        return tuple(
            item
            for item in (hidden_states, encoder_states, all_attentions)
            if item is not None
        )
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100DecoderLayer->RotaryIndicTrans
class RotaryIndicTransDecoder(RotaryIndicTransPreTrainedModel):
    """
    Decoder consisting of `config.decoder_layers` [`RotaryIndicTransDecoderLayer`] modules.

    Target tokens are embedded (optionally scaled by `sqrt(embed_dim)`), passed through an optional
    embedding LayerNorm and dropout, then through the layer stack (with LayerDrop while training) and
    an optional final LayerNorm.

    Args:
        config (`RotaryIndicTransConfig`): model configuration.
        embed_tokens (`nn.Embedding`, *optional*): embedding whose weight is shared with this decoder's
            own embedding table.
    """

    def __init__(
        self,
        config: RotaryIndicTransConfig,
        embed_tokens: Optional[nn.Embedding] = None,
    ):
        super().__init__(config)
        self.dropout = config.dropout
        self.layerdrop = config.decoder_layerdrop

        # Use the decoder width: the decoder output is projected by an `lm_head` built with
        # `config.decoder_embed_dim`, and the embedding may be tied to that head.
        embed_dim = config.decoder_embed_dim
        self.padding_idx = config.pad_token_id
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        self.embed_tokens = nn.Embedding(
            config.decoder_vocab_size, embed_dim, self.padding_idx
        )

        if embed_tokens is not None:
            self.embed_tokens.weight = embed_tokens.weight

        self.layers = nn.ModuleList(
            [RotaryIndicTransDecoderLayer(config) for _ in range(config.decoder_layers)]
        )

        self.layer_norm = (
            nn.LayerNorm(embed_dim) if config.decoder_normalize_before else None
        )
        self.layernorm_embedding = (
            nn.LayerNorm(embed_dim) if config.layernorm_embedding else None
        )

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self._use_sdpa = config._attn_implementation == "sdpa"

        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
                Sequence of hidden-states at the output of the last layer of the encoder. Used in the
                cross-attention of the decoder.
            encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
                Mask to avoid performing cross-attention on padding tokens indices of encoder input_ids. Mask values
                selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the cross-attention modules in the decoder to avoid performing
                cross-attention on hidden heads. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
                Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
                shape `(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of
                shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

                Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
                cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential
                decoding.

                If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
                (those that don't have their past key value states given to this model) of shape
                `(batch_size, 1)` instead of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded
                representation. This is useful if you want more control over how to convert `input_ids` indices
                into associated vectors than the model's internal embedding lookup matrix.
            use_cache (`bool`, *optional*):
                Whether to collect and return the per-layer key/value states. Defaults to `config.use_cache`;
                forced to `False` when gradient checkpointing is active during training.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned
                tensors for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions`, or (when `return_dict` is false) a tuple with the
            non-`None` entries of `(last_hidden_state, past_key_values, hidden_states, attentions,
            cross_attentions)`.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds` are given, or if `head_mask` /
                `cross_attn_head_mask` does not have one entry per decoder layer.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        # Number of already-decoded positions, read from dim 2 of the first cached tensor.
        past_key_values_length = (
            past_key_values[0][0].shape[2] if past_key_values is not None else 0
        )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        # Prepare the self-attention mask in the form the selected attention backend expects.
        if self._use_flash_attention_2:
            # The flash-attention path keeps the 2D mask, and only when something is actually padded.
            attention_mask = (
                attention_mask
                if (attention_mask is not None and 0 in attention_mask)
                else None
            )
        elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                input_shape,
                inputs_embeds,
                past_key_values_length,
            )
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, past_key_values_length
            )

        # Same treatment for the cross-attention (encoder padding) mask.
        if encoder_hidden_states is not None and encoder_attention_mask is not None:
            if self._use_flash_attention_2:
                encoder_attention_mask = (
                    encoder_attention_mask if 0 in encoder_attention_mask else None
                )
            elif (
                self._use_sdpa
                and cross_attn_head_mask is None
                and not output_attentions
            ):
                encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask,
                    inputs_embeds.dtype,
                    tgt_len=input_shape[-1],
                )
            else:
                encoder_attention_mask = _prepare_4d_attention_mask(
                    encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
                )

        hidden_states = inputs_embeds

        if self.layernorm_embedding is not None:
            hidden_states = self.layernorm_embedding(hidden_states)

        hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting"
                    " `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        # Both head masks must provide one entry per decoder layer.
        for attn_mask, mask_name in zip(
            [head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]
        ):
            if attn_mask is not None and attn_mask.size()[0] != len(self.layers):
                # Report the size of the mask being checked; reading `head_mask` here would fail
                # with an AttributeError when only `cross_attn_head_mask` was supplied.
                raise ValueError(
                    f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
                    f" {attn_mask.size()[0]}."
                )

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            # LayerDrop: one random draw per layer; it only takes effect while training.
            dropout_probability = torch.rand([])
            skip_the_layer = bool(
                self.training and (dropout_probability < self.layerdrop)
            )

            # The layer still executes under DeepSpeed ZeRO-3 even when LayerDrop selected it.
            if not skip_the_layer or deepspeed_zero3_is_enabled:
                past_key_value = (
                    past_key_values[idx] if past_key_values is not None else None
                )

                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            # None for past_key_value
                            return module(*inputs, output_attentions, use_cache)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(decoder_layer),
                        hidden_states,
                        attention_mask,
                        encoder_hidden_states,
                        encoder_attention_mask,
                        head_mask[idx] if head_mask is not None else None,
                        (
                            cross_attn_head_mask[idx]
                            if cross_attn_head_mask is not None
                            else None
                        ),
                        None,
                    )
                else:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        encoder_hidden_states=encoder_hidden_states,
                        encoder_attention_mask=encoder_attention_mask,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        cross_attn_layer_head_mask=(
                            cross_attn_head_mask[idx]
                            if cross_attn_head_mask is not None
                            else None
                        ),
                        past_key_value=past_key_value,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                    )

                hidden_states = layer_outputs[0]

            if skip_the_layer:
                # A dropped layer contributes neither cache entries nor attention maps.
                continue

            if use_cache:
                # The cache entry follows the two attention tensors when those are returned.
                next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)
                all_cross_attentions += (layer_outputs[2],)

        if self.layer_norm is not None:
            hidden_states = self.layer_norm(hidden_states)

        # Add the hidden state produced by the last decoder layer (after the final norm).
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_cross_attentions,
                ]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            cross_attentions=all_cross_attentions,
        )
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100Model->RotaryIndicTrans
class RotaryIndicTransModel(RotaryIndicTransPreTrainedModel):
    """
    Bare encoder-decoder model: a `RotaryIndicTransEncoder` whose last hidden state is consumed by a
    `RotaryIndicTransDecoder` through cross-attention. No language-modeling head is attached.
    """

    def __init__(self, config: RotaryIndicTransConfig):
        super().__init__(config)

        self.encoder = RotaryIndicTransEncoder(config)
        self.decoder = RotaryIndicTransDecoder(config)

        self.post_init()

    def get_encoder(self):
        """Return the encoder sub-module."""
        return self.encoder

    def get_decoder(self):
        """Return the decoder sub-module."""
        return self.decoder

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
        """
        Encode the source (unless `encoder_outputs` is supplied) and decode the target.

        The encoder receives `input_ids` / `inputs_embeds`, `attention_mask` and `head_mask`. The
        decoder receives the `decoder_*` inputs, `cross_attn_head_mask`, `past_key_values` and
        `use_cache`, and cross-attends to the first element of the encoder output using
        `attention_mask` as the encoder padding mask.

        Unset `output_attentions`, `output_hidden_states`, `use_cache` and `return_dict` fall back to
        the corresponding `config` values.

        Returns:
            A `Seq2SeqModelOutput`, or (when `return_dict` is false) the decoder outputs concatenated
            with the encoder outputs.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if use_cache is None:
            use_cache = cfg.use_cache
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                head_mask=head_mask,
                inputs_embeds=inputs_embeds,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # A caller-provided plain sequence is wrapped so that attribute access works below.
            n_fields = len(encoder_outputs)
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if n_fields > 1 else None,
                attentions=encoder_outputs[2] if n_fields > 2 else None,
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            inputs_embeds=decoder_inputs_embeds,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs[0],
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            return Seq2SeqModelOutput(
                last_hidden_state=decoder_outputs.last_hidden_state,
                past_key_values=decoder_outputs.past_key_values,
                decoder_hidden_states=decoder_outputs.hidden_states,
                decoder_attentions=decoder_outputs.attentions,
                cross_attentions=decoder_outputs.cross_attentions,
                encoder_last_hidden_state=encoder_outputs.last_hidden_state,
                encoder_hidden_states=encoder_outputs.hidden_states,
                encoder_attentions=encoder_outputs.attentions,
            )

        return decoder_outputs + encoder_outputs
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ForConditionalGeneration->RotaryIndicTrans
class RotaryIndicTransForConditionalGeneration(
    RotaryIndicTransPreTrainedModel, GenerationMixin
):
    """
    `RotaryIndicTransModel` with a bias-free linear language-modeling head on top of the decoder
    output, usable for training (via `labels`) and for generation (via `GenerationMixin`).
    """

    base_model_prefix = "model"
    _tied_weights_keys = ["decoder.embed_tokens.weight", "lm_head.weight"]

    def __init__(self, config: RotaryIndicTransConfig):
        super().__init__(config)
        self.model = RotaryIndicTransModel(config)
        self.lm_head = nn.Linear(
            config.decoder_embed_dim, config.decoder_vocab_size, bias=False
        )

        self.post_init()

    def get_encoder(self):
        """Return the encoder of the wrapped model."""
        return self.model.encoder

    def get_decoder(self):
        """Return the decoder of the wrapped model."""
        return self.model.decoder

    def get_input_embeddings(self):
        """Return the encoder's token embedding."""
        return self.model.encoder.embed_tokens

    def get_output_embeddings(self):
        """Return the language-modeling head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the language-modeling head with `new_embeddings`."""
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Tie the decoder embedding to `lm_head` when `config.share_decoder_input_output_embed` is set."""
        if self.config.share_decoder_input_output_embed:
            self._tie_or_clone_weights(self.model.decoder.embed_tokens, self.lm_head)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in
            `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100`
            are ignored (masked), the loss is only computed for the tokens with labels in
            `[0, ..., config.vocab_size]`.

            When `labels` is given and neither `decoder_input_ids` nor `decoder_inputs_embeds` is, the decoder
            inputs are derived from the labels with `shift_tokens_right`.

        Returns:
            A `Seq2SeqLMOutput`, or (when `return_dict` is false) a tuple `(loss?, logits, *model_outputs[1:])`
            where `loss` is only present when `labels` was provided.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if labels is not None:
            # Only derive decoder inputs from the labels when the caller gave none at all. Building
            # `decoder_input_ids` next to user-supplied `decoder_inputs_embeds` would make the decoder
            # reject the call for receiving both.
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        lm_logits = self.lm_head(outputs[0])

        masked_lm_loss = None
        if labels is not None:
            # Move labels to the logits' device before computing the loss.
            labels = labels.to(lm_logits.device)
            masked_lm_loss = F.cross_entropy(
                input=lm_logits.view(-1, self.config.decoder_vocab_size),
                target=labels.view(-1),
            )

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return Seq2SeqLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

    def prepare_inputs_for_generation(
        self,
        decoder_input_ids,
        past_key_values=None,
        attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        use_cache=None,
        encoder_outputs=None,
        **kwargs,
    ):
        """
        Build the keyword arguments for one generation step.

        When a cache is present only the last decoder token is kept. `input_ids` is always `None` in
        the returned dict and `encoder_outputs` is passed through instead. Extra `kwargs` are ignored.
        """
        if past_key_values is not None:
            # Only the newest token is fed once a cache exists.
            decoder_input_ids = decoder_input_ids[:, -1:]

        return {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "past_key_values": past_key_values,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
            "head_mask": head_mask,
            "decoder_head_mask": decoder_head_mask,
            "cross_attn_head_mask": cross_attn_head_mask,
            "use_cache": use_cache,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """
        Reorder every cached tensor along dim 0 according to `beam_idx`.

        Args:
            past_key_values: iterable of per-layer iterables of cached tensors.
            beam_idx: index tensor selecting the entries to keep along dim 0.

        Returns:
            A tuple of per-layer tuples containing the re-indexed tensors.
        """
        # `index_select` needs the index on the same device as the tensor it indexes; cached states
        # can live on different devices than `beam_idx` when the layers are spread across devices.
        return tuple(
            tuple(
                past_state.index_select(0, beam_idx.to(past_state.device))
                for past_state in layer_past
            )
            for layer_past in past_key_values
        )