| | from dataclasses import dataclass |
| | from functools import partial |
| | from typing import Optional, Tuple, Union |
| |
|
| | import torch |
| | from torch import nn |
| | from transformers import AutoConfig, AutoModelForCausalLM |
| | from transformers.cache_utils import DynamicCache |
| | from transformers.modeling_flash_attention_utils import FlashAttentionKwargs |
| | from transformers.modeling_outputs import ( |
| | BaseModelOutputWithPast, |
| | CausalLMOutputWithPast, |
| | ModelOutput, |
| | ) |
| | from transformers.processing_utils import Unpack |
| | from transformers.utils import logging |
| |
|
| | from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM |
| |
|
# ``BlockMask`` only exists in PyTorch builds that ship flex attention. Bind
# the name to ``None`` first so that the ``Union[..., BlockMask]`` annotations
# in this module still evaluate when the import below is unavailable.
BlockMask = None
try:
    from torch.nn.attention.flex_attention import BlockMask  # noqa: F811
except ImportError:
    # Deliberate best effort: keep the ``None`` placeholder bound above.
    pass

# Module logger (transformers' logging wrapper).
logger = logging.get_logger(__name__)
| |
|
| |
|
@dataclass
class EncoderBaseModelOutputWithPast(ModelOutput):
    """Result of an encoder-only call.

    Carries the updated encoder KV cache and the encoder's last hidden state,
    together with the decoder cache, which is passed through unchanged.

    Attributes:
        past_key_values: Decoder cache as supplied by the caller.
        encoder_last_hidden_state: Final hidden state of the encoder backbone.
        encoder_past_key_values: Encoder cache after the encoder pass.
    """

    # Field order matters for ``ModelOutput`` (tuple conversion / indexing),
    # so it is kept exactly as declared.
    past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None
| |
|
| |
|
@dataclass
class DecoderCausalLMOutputWithPast(ModelOutput):
    """Result of a decoder call.

    Carries the decoder logits and the updated decoder KV cache, together
    with the encoder cache as it was handed in by the caller.

    Attributes:
        logits: Decoder output logits.
        past_key_values: Decoder cache after this call.
        encoder_past_key_values: Encoder cache, passed through.
    """

    # Field order matters for ``ModelOutput`` (tuple conversion / indexing),
    # so it is kept exactly as declared.
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None
    encoder_past_key_values: Optional[
        Union[Tuple[Tuple[torch.FloatTensor]], DynamicCache]
    ] = None
| |
|
| |
|
class LLMasEncoderDecoder(nn.Module):
    """Use a pretrained causal LM as an encoder-decoder model.

    ``self.encoder`` is a ``CustomQwen3ForCausalLM`` whose backbone
    (``self.encoder.model``) produces hidden states and an encoder KV cache.
    ``self.decoder`` is one of two things:

    * the very same module (``tie_encoder_decoder_weights=True``), in which
      case only layers whose ``self_attn.layer_idx`` is listed in
      ``self.decoder_layer_idxs`` run as decoder layers; or
    * a second ``CustomQwen3ForCausalLM`` whose token embedding is deleted.

    In both cases decoder tokens are embedded with
    ``self.encoder.model.embed_tokens``.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        max_length: int,
        attn_backend: str = "sdpa",
        freeze_encoder: bool = False,
        reinit_encoder: bool = False,
        reinit_decoder: bool = False,
        tie_encoder_decoder_weights: bool = False,
        use_encoder_causal_mask: bool = False,
        num_encoder_layers: int = -1,
        num_decoder_layers: int = -1,
        keep_top_encoder_layers: bool = False,
        keep_top_decoder_layers: bool = False,
        use_gradient_checkpointing: bool = False,
        **llm_init_kwargs,
    ):
        """Build the encoder and decoder backbones.

        Args:
            pretrained_model_name_or_path: Model id or path used for the
                config and, unless re-initialised, the weights.
            max_length: Stored as ``self.max_length``; not otherwise used here.
            attn_backend: Passed as ``attn_implementation``.
            freeze_encoder: Set ``requires_grad=False`` on every encoder
                parameter whose name does not contain ``embed_tokens``.
            reinit_encoder: Build the encoder from config (no pretrained
                weights) with exactly ``num_encoder_layers`` layers.
            reinit_decoder: Same for the decoder, with ``num_decoder_layers``.
            tie_encoder_decoder_weights: Reuse the encoder module as decoder.
            use_encoder_causal_mask: In ``forward``, call the encoder with
                ``attention_mask=None`` instead of ``encoder_attention_mask``.
            num_encoder_layers: Layers to keep in the encoder; ``-1`` keeps all.
            num_decoder_layers: Layers to keep in the decoder (or, when tied,
                the number of top layers used as decoder); ``-1`` keeps all.
            keep_top_encoder_layers: Keep the last layers rather than the
                first ones when truncating the encoder.
            keep_top_decoder_layers: Same for the (untied) decoder.
            use_gradient_checkpointing: Call ``gradient_checkpointing_enable()``
                on the encoder and, when separate, on the decoder.
            **llm_init_kwargs: Forwarded to ``AutoConfig.from_pretrained`` /
                ``CustomQwen3ForCausalLM.from_pretrained``.
        """
        assert not (tie_encoder_decoder_weights and reinit_decoder), (
            "Cannot tie encoder-decoder weights and reinitialize decoder."
        )
        assert not (tie_encoder_decoder_weights and freeze_encoder), (
            "Cannot freeze encoder weights when tying encoder-decoder weights."
        )
        super().__init__()
        self.use_encoder_causal_mask = use_encoder_causal_mask
        self.tie_encoder_decoder_weights = tie_encoder_decoder_weights

        # ---------------- Encoder ----------------
        if reinit_encoder:
            assert num_encoder_layers > 0
            encoder_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                num_hidden_layers=num_encoder_layers,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            self.encoder = CustomQwen3ForCausalLM(encoder_config)
        else:
            self.encoder = CustomQwen3ForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
        assert num_encoder_layers <= len(self.encoder.model.layers), (
            f"Cannot keep {num_encoder_layers} layers. "
            f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
        )
        self.encoder.model.layers = self._select_layers(
            self.encoder.model.layers, num_encoder_layers, keep_top_encoder_layers
        )

        if freeze_encoder:
            for name, param in self.encoder.named_parameters():
                # The decoder embeds its tokens with this table: keep it trainable.
                if "embed_tokens" not in name:
                    param.requires_grad = False
        if use_gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        # ---------------- Decoder ----------------
        if tie_encoder_decoder_weights:
            self.decoder = self.encoder
            num_decoder_layers = (
                len(self.decoder.model.layers)
                if num_decoder_layers == -1
                else num_decoder_layers
            )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )
            # The top ``num_decoder_layers`` layers double as decoder layers.
            # ``forward`` compares this list against ``self_attn.layer_idx``,
            # and slicing the ModuleList above does not renumber ``layer_idx``,
            # so store those attribute values rather than list positions.
            self.decoder_layer_idxs = [
                layer.self_attn.layer_idx for layer in self.encoder.model.layers
            ][-num_decoder_layers:]
        else:
            if reinit_decoder:
                assert num_decoder_layers > 0
                decoder_config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    num_hidden_layers=num_decoder_layers,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                self.decoder = CustomQwen3ForCausalLM(decoder_config)
            else:
                self.decoder = CustomQwen3ForCausalLM.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )
            # ``_select_layers`` maps -1 to "all layers"; slicing with a raw -1
            # would silently drop one decoder layer.
            self.decoder.model.layers = self._select_layers(
                self.decoder.model.layers, num_decoder_layers, keep_top_decoder_layers
            )
            # Decoder inputs are embedded with the encoder's table (see forward).
            del self.decoder.model.embed_tokens
            # If the pretrained LM ties lm_head to its input embedding, let the
            # decoder share the encoder's lm_head. Otherwise the encoder's
            # lm_head is never used (forward only calls self.decoder.lm_head).
            if (
                self.encoder.lm_head.weight.data_ptr()
                == self.encoder.model.embed_tokens.weight.data_ptr()
            ):
                self.decoder.lm_head = self.encoder.lm_head
            else:
                del self.encoder.lm_head
            if use_gradient_checkpointing:
                self.decoder.gradient_checkpointing_enable()
        self.max_length = max_length

    @staticmethod
    def _select_layers(
        layers: nn.ModuleList, num_layers: int, keep_top: bool
    ) -> nn.ModuleList:
        """Return the first ``num_layers`` layers (the last ones if ``keep_top``).

        ``num_layers == -1`` selects every layer.
        """
        if num_layers == -1:
            num_layers = len(layers)
        return layers[-num_layers:] if keep_top else layers[:num_layers]

    def freeze_encoder(self):
        """Disable gradients for all encoder-backbone parameters.

        This includes ``embed_tokens``, which the decoder also uses.
        """
        for p in self.encoder.model.parameters():
            p.requires_grad = False

    def unfreeze_encoder(self):
        """Re-enable gradients for all encoder-backbone parameters."""
        for p in self.encoder.model.parameters():
            p.requires_grad = True

    def forward(
        self,
        # Decoder inputs
        input_ids: torch.LongTensor,
        attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        encoder_last_hidden_state: Optional[torch.FloatTensor] = None,
        # Encoder inputs
        encoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        encoder_position_ids: Optional[torch.LongTensor] = None,
        encoder_cache_position: Optional[torch.LongTensor] = None,
        encoder_past_key_values: Optional[DynamicCache] = None,
        # Misc
        fix_cache_length: bool = True,
        return_updated_cache: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[DecoderCausalLMOutputWithPast, EncoderBaseModelOutputWithPast]:
        """Run the optional encoder pass, then the decoder.

        Encoder step: if ``encoder_input_ids`` is given, ``self.encoder.model``
        runs with ``use_cache=True`` on ``encoder_past_key_values``.
        - With ``return_updated_cache=True`` the method returns immediately.
          The result is an ``EncoderBaseModelOutputWithPast`` holding the
          encoder hidden state and cache; ``past_key_values`` is passed through.
        - Otherwise the encoder output becomes ``encoder_last_hidden_state``.

        Decoder step: tokens are embedded with the encoder's embedding table.
        - If an encoder hidden state is available, it is prepended to the
          embeddings. Its length is passed as ``q_start_idx`` to every decoder
          layer.
        - Each layer's output is re-prefixed with the encoder states. This
          suggests the custom layer returns only positions ``>= q_start_idx``;
          confirm in ``backbone_custom_modeling_qwen3``.
        - Only the decoder positions are normalised and projected to logits.

        Decoder cache: after each layer it is truncated to its pre-layer
        length plus ``new_seen_tokens``. ``new_seen_tokens`` is the length of
        the *caller-supplied* ``encoder_last_hidden_state``.

        ``fix_cache_length`` is accepted but not used.

        Returns:
            ``EncoderBaseModelOutputWithPast`` for an encoder-only call.
            Otherwise ``DecoderCausalLMOutputWithPast`` with the decoder
            logits, the truncated decoder cache, and ``encoder_past_key_values``
            as passed in.
        """
        # NOTE(review): read before the optional inline encoder pass below, so
        # encoder states computed within this call do not count toward the
        # retained decoder cache or the default decoder position offset.
        # Confirm this is intended.
        new_seen_tokens = (
            0
            if encoder_last_hidden_state is None
            else encoder_last_hidden_state.shape[1]
        )

        # ---------------- Optional encoder pass ----------------
        if encoder_input_ids is not None:
            if self.use_encoder_causal_mask:
                encoder_attention_mask = None
            if encoder_cache_position is None and encoder_position_ids is not None:
                encoder_cache_position = encoder_position_ids[0]
            encoder_output = self.encoder.model(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                position_ids=encoder_position_ids,
                use_cache=True,
                past_key_values=encoder_past_key_values,
                cache_position=encoder_cache_position,
            )
            if return_updated_cache:
                # Encoder-only call: hand back the updated encoder cache.
                return EncoderBaseModelOutputWithPast(
                    encoder_last_hidden_state=encoder_output.last_hidden_state,
                    encoder_past_key_values=encoder_output.past_key_values,
                    past_key_values=past_key_values,
                )
            encoder_last_hidden_state = encoder_output.last_hidden_state

        # ---------------- Decoder inputs ----------------
        if encoder_last_hidden_state is None:
            # Decoder-only: every position is a query.
            q_start_idx = 0
            decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
            if cache_position is None:
                if position_ids is not None:
                    cache_position = position_ids[0]
                else:
                    past_seen_tokens = (
                        past_key_values.get_seq_length()
                        if past_key_values is not None
                        else 0
                    )
                    cache_position = torch.arange(
                        past_seen_tokens,
                        past_seen_tokens + decoder_hidden_states.shape[1],
                        device=decoder_hidden_states.device,
                    )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            decoder_position_embeddings = self.decoder.model.rotary_emb(
                decoder_hidden_states, position_ids
            )
        else:
            # Sequence is [encoder states; decoder tokens]; queries start after
            # the encoder prefix.
            q_start_idx = encoder_last_hidden_state.shape[1]
            decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
            decoder_hidden_states = torch.cat(
                [
                    encoder_last_hidden_state,
                    decoder_hidden_states,
                ],
                dim=1,
            )
            if cache_position is None:
                if position_ids is not None:
                    cache_position = position_ids[0]
                else:
                    past_seen_tokens = (
                        past_key_values.get_seq_length()
                        if past_key_values is not None
                        else 0
                    )
                    cache_position = torch.cat(
                        [
                            torch.arange(
                                past_seen_tokens,
                                past_seen_tokens + encoder_last_hidden_state.shape[1],
                                device=decoder_hidden_states.device,
                            ),
                            torch.arange(
                                past_seen_tokens + new_seen_tokens,
                                past_seen_tokens + new_seen_tokens + input_ids.shape[1],
                                device=decoder_hidden_states.device,
                            ),
                        ],
                        dim=-1,
                    )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)
            decoder_position_embeddings = self.decoder.model.rotary_emb(
                decoder_hidden_states, position_ids
            )

        if hasattr(self.decoder.model, "_update_causal_mask"):
            # Let the HF backbone build/convert the mask for its attention backend.
            attention_mask = self.decoder.model._update_causal_mask(
                attention_mask=attention_mask,
                input_tensor=decoder_hidden_states,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_attentions=False,
            )

        # ---------------- Decoder layers ----------------
        for decoder_layer in self.decoder.model.layers:
            layer_idx = decoder_layer.self_attn.layer_idx
            if (
                self.tie_encoder_decoder_weights
                and layer_idx not in self.decoder_layer_idxs
            ):
                # Shared module: only the selected top layers act as decoder.
                continue

            # Cache length before this layer writes to it; afterwards the cache
            # is cut back to this plus the caller-supplied encoder positions.
            if past_key_values is not None and len(past_key_values) > layer_idx:
                prev_cache_len = past_key_values[layer_idx][0].shape[-2]
            else:
                prev_cache_len = 0
            cache_len = prev_cache_len + new_seen_tokens

            if self.decoder.model.gradient_checkpointing and self.training:
                # ``gradient_checkpointing_enable`` stores the checkpoint function
                # on the modules that own ``gradient_checkpointing`` (the backbone
                # checked above), so fetch it from that same module.
                # Positional order mirrors the keyword call in the else-branch.
                decoder_hidden_states = self.decoder.model._gradient_checkpointing_func(
                    partial(decoder_layer.__call__, **flash_attn_kwargs),
                    decoder_hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    False,  # output_attentions
                    True,  # use_cache
                    cache_position,
                    decoder_position_embeddings,
                    q_start_idx,
                )[0]
            else:
                decoder_hidden_states = decoder_layer(
                    hidden_states=decoder_hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=False,
                    use_cache=True,
                    cache_position=cache_position,
                    position_embeddings=decoder_position_embeddings,
                    q_start_idx=q_start_idx,
                    **flash_attn_kwargs,
                )[0]
            if q_start_idx > 0:
                # Re-attach the (fixed) encoder prefix for the next layer.
                decoder_hidden_states = torch.cat(
                    [
                        encoder_last_hidden_state,
                        decoder_hidden_states,
                    ],
                    dim=1,
                )

            if past_key_values is not None:
                # Drop cache entries beyond ``cache_len`` written by this layer.
                past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
                    layer_idx
                ][..., :cache_len, :]
                past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
                    layer_idx
                ][..., :cache_len, :]

        # Only decoder positions are normalised and projected to the vocabulary.
        decoder_hidden_states = self.decoder.model.norm(
            decoder_hidden_states[:, q_start_idx:, :]
        )
        logits = self.decoder.lm_head(decoder_hidden_states)
        return DecoderCausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
            encoder_past_key_values=encoder_past_key_values,
        )
| |
|
| |
|
class LLMasEncoderDecoderShareKV(nn.Module):
    """Encoder-decoder variant in which the decoder reads the encoder's KV cache.

    - The encoder backbone (``self.encoder.model``) writes its keys/values
      into ``past_key_values``.
    - Decoder layers then attend with that same cache object.
    - After each decoder layer, that layer's cache entry is truncated back to
      its pre-layer length.

    Backbones are built with ``AutoModelForCausalLM``. When weights are tied,
    only layers whose ``self_attn.layer_idx`` is listed in
    ``self.decoder_layer_idxs`` run as decoder layers.
    """

    def __init__(
        self,
        pretrained_model_name_or_path: str,
        max_length: int,
        attn_backend: str = "sdpa",
        freeze_encoder: bool = False,
        reinit_encoder: bool = False,
        reinit_decoder: bool = False,
        tie_encoder_decoder_weights: bool = False,
        use_encoder_causal_mask: bool = False,
        num_encoder_layers: int = -1,
        num_decoder_layers: int = -1,
        keep_top_encoder_layers: bool = False,
        keep_top_decoder_layers: bool = False,
        use_gradient_checkpointing: bool = False,
        **llm_init_kwargs,
    ):
        """Build the encoder and decoder backbones.

        Args:
            pretrained_model_name_or_path: Model id or path used for the
                config and, unless re-initialised, the weights.
            max_length: Stored as ``self.max_length``; not otherwise used here.
            attn_backend: Passed as ``attn_implementation``.
            freeze_encoder: Set ``requires_grad=False`` on every encoder
                parameter whose name does not contain ``embed_tokens``.
            reinit_encoder: Build the encoder from config (no pretrained
                weights) with exactly ``num_encoder_layers`` layers.
            reinit_decoder: Same for the decoder, with ``num_decoder_layers``.
            tie_encoder_decoder_weights: Reuse the encoder module as decoder.
            use_encoder_causal_mask: In ``forward``, call the encoder with
                ``attention_mask=None`` instead of ``encoder_attention_mask``.
            num_encoder_layers: Layers to keep in the encoder; ``-1`` keeps all.
            num_decoder_layers: Layers to keep in the decoder (or, when tied,
                the number of top layers used as decoder); ``-1`` keeps all.
            keep_top_encoder_layers: Keep the last layers rather than the
                first ones when truncating the encoder.
            keep_top_decoder_layers: Same for the (untied) decoder.
            use_gradient_checkpointing: Call ``gradient_checkpointing_enable()``
                on the encoder and, when separate, on the decoder.
            **llm_init_kwargs: Forwarded to ``AutoConfig.from_pretrained`` /
                ``AutoModelForCausalLM.from_pretrained``.
        """
        assert not (tie_encoder_decoder_weights and reinit_decoder), (
            "Cannot tie encoder-decoder weights and reinitialize decoder."
        )
        assert not (tie_encoder_decoder_weights and freeze_encoder), (
            "Cannot freeze encoder weights when tying encoder-decoder weights."
        )
        super().__init__()
        self.use_encoder_causal_mask = use_encoder_causal_mask
        self.tie_encoder_decoder_weights = tie_encoder_decoder_weights

        # ---------------- Encoder ----------------
        if reinit_encoder:
            assert num_encoder_layers > 0
            encoder_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                num_hidden_layers=num_encoder_layers,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
            self.encoder = AutoModelForCausalLM.from_config(encoder_config)
        else:
            self.encoder = AutoModelForCausalLM.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=True,
                attn_implementation=attn_backend,
                **llm_init_kwargs,
            )
        assert num_encoder_layers <= len(self.encoder.model.layers), (
            f"Cannot keep {num_encoder_layers} layers. "
            f"Pre-trained model only has {len(self.encoder.model.layers)} layers."
        )
        self.encoder.model.layers = self._select_layers(
            self.encoder.model.layers, num_encoder_layers, keep_top_encoder_layers
        )

        if freeze_encoder:
            for name, param in self.encoder.named_parameters():
                # The decoder embeds its tokens with this table: keep it trainable.
                if "embed_tokens" not in name:
                    param.requires_grad = False
        if use_gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        # ---------------- Decoder ----------------
        if tie_encoder_decoder_weights:
            self.decoder = self.encoder
            num_decoder_layers = (
                len(self.decoder.model.layers)
                if num_decoder_layers == -1
                else num_decoder_layers
            )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )
            # The top ``num_decoder_layers`` layers double as decoder layers.
            # ``forward`` compares this list against ``self_attn.layer_idx``,
            # and slicing the ModuleList above does not renumber ``layer_idx``,
            # so store those attribute values rather than list positions.
            self.decoder_layer_idxs = [
                layer.self_attn.layer_idx for layer in self.encoder.model.layers
            ][-num_decoder_layers:]
        else:
            if reinit_decoder:
                assert num_decoder_layers > 0
                decoder_config = AutoConfig.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    num_hidden_layers=num_decoder_layers,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
                # Auto classes cannot be instantiated directly; build from config
                # exactly as the encoder path does.
                self.decoder = AutoModelForCausalLM.from_config(decoder_config)
            else:
                self.decoder = AutoModelForCausalLM.from_pretrained(
                    pretrained_model_name_or_path,
                    trust_remote_code=True,
                    attn_implementation=attn_backend,
                    **llm_init_kwargs,
                )
            assert num_decoder_layers <= len(self.decoder.model.layers), (
                f"Cannot keep {num_decoder_layers} layers. "
                f"Pre-trained model only has {len(self.decoder.model.layers)} layers."
            )
            # ``_select_layers`` maps -1 to "all layers"; slicing with a raw -1
            # would silently drop one decoder layer.
            self.decoder.model.layers = self._select_layers(
                self.decoder.model.layers, num_decoder_layers, keep_top_decoder_layers
            )
            # Decoder inputs are embedded with the encoder's table (see forward).
            del self.decoder.model.embed_tokens

            # forward() only keeps the encoder's ``past_key_values``. Parameters
            # of the encoder's last layer that act after its K/V projections,
            # plus the final norm, therefore never reach the decoder and are
            # frozen here.
            self.encoder.model.embed_tokens.requires_grad_(True)
            last_encoder_layer = self.encoder.model.layers[-1]
            unused_self_attn_params = ["o_proj", "q_norm", "q_proj"]
            # NOTE(review): in HF decoder layers ``input_layernorm`` normally
            # feeds k_proj/v_proj, so its output does reach the shared cache;
            # freezing it may be unintended. Confirm.
            unused_layernorm_params = ["input_layernorm", "post_attention_layernorm"]
            for unused_param in unused_self_attn_params:
                if hasattr(last_encoder_layer.self_attn, unused_param):
                    getattr(last_encoder_layer.self_attn, unused_param).requires_grad_(
                        False
                    )
            last_encoder_layer.mlp.requires_grad_(False)
            self.encoder.model.norm.requires_grad_(False)
            for unused_param in unused_layernorm_params:
                if hasattr(last_encoder_layer, unused_param):
                    getattr(last_encoder_layer, unused_param).requires_grad_(False)

            # If the pretrained LM ties lm_head to its input embedding, let the
            # decoder share the encoder's lm_head. Otherwise the encoder's
            # lm_head is never used (forward only calls self.decoder.lm_head).
            if (
                self.encoder.lm_head.weight.data_ptr()
                == self.encoder.model.embed_tokens.weight.data_ptr()
            ):
                self.decoder.lm_head = self.encoder.lm_head
            else:
                del self.encoder.lm_head
            if use_gradient_checkpointing:
                self.decoder.gradient_checkpointing_enable()
        self.max_length = max_length

    @staticmethod
    def _select_layers(
        layers: nn.ModuleList, num_layers: int, keep_top: bool
    ) -> nn.ModuleList:
        """Return the first ``num_layers`` layers (the last ones if ``keep_top``).

        ``num_layers == -1`` selects every layer.
        """
        if num_layers == -1:
            num_layers = len(layers)
        return layers[-num_layers:] if keep_top else layers[:num_layers]

    def freeze_encoder(self):
        """Disable gradients for all encoder-backbone parameters.

        This includes ``embed_tokens``, which the decoder also uses.
        """
        for p in self.encoder.model.parameters():
            p.requires_grad = False

    def unfreeze_encoder(self):
        """Re-enable gradients for all encoder-backbone parameters."""
        for p in self.encoder.model.parameters():
            p.requires_grad = True

    def forward(
        self,
        # Decoder inputs
        input_ids: torch.LongTensor,
        attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        past_key_values: Optional[DynamicCache] = None,
        encoder_last_hidden_state: Optional[torch.FloatTensor] = None,
        # Encoder inputs
        encoder_input_ids: Optional[torch.LongTensor] = None,
        encoder_attention_mask: Optional[Union[torch.FloatTensor, BlockMask]] = None,
        encoder_position_ids: Optional[torch.LongTensor] = None,
        encoder_cache_position: Optional[torch.LongTensor] = None,
        encoder_past_key_values: Optional[DynamicCache] = None,
        # Misc
        fix_cache_length: bool = True,
        return_updated_cache: bool = False,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[CausalLMOutputWithPast, BaseModelOutputWithPast]:
        """Optionally extend the shared cache with the encoder, then decode.

        Encoder step: if ``encoder_input_ids`` is given, ``self.encoder.model``
        runs with ``use_cache=True`` on ``past_key_values``, and its cache
        replaces ``past_key_values``. With ``return_updated_cache=True`` that
        cache is returned immediately in a ``BaseModelOutputWithPast``.

        Decoder step:
        - Tokens are embedded with the encoder's embedding table.
        - Each decoder layer attends using ``past_key_values``.
        - Each layer's cache entry is then truncated to its pre-layer length,
          so decoding does not grow the cache.

        ``encoder_last_hidden_state``, ``encoder_past_key_values`` and
        ``fix_cache_length`` are accepted but not used.

        Returns:
            ``BaseModelOutputWithPast`` for an encoder-only call. Otherwise
            ``CausalLMOutputWithPast`` with the decoder logits and the cache.
        """
        # ---------------- Optional encoder pass (fills shared cache) ----------------
        if encoder_input_ids is not None:
            if self.use_encoder_causal_mask:
                encoder_attention_mask = None
            if encoder_cache_position is None and encoder_position_ids is not None:
                encoder_cache_position = encoder_position_ids[0]
            past_key_values = self.encoder.model(
                input_ids=encoder_input_ids,
                attention_mask=encoder_attention_mask,
                position_ids=encoder_position_ids,
                use_cache=True,
                past_key_values=past_key_values,
                cache_position=encoder_cache_position,
            ).past_key_values
            if return_updated_cache:
                # Encoder-only call: hand back the updated shared cache.
                return BaseModelOutputWithPast(
                    past_key_values=past_key_values,
                )

        # ---------------- Decoder inputs ----------------
        decoder_hidden_states = self.encoder.model.embed_tokens(input_ids)
        if cache_position is None:
            if position_ids is not None:
                cache_position = position_ids[0]
            else:
                # NOTE(review): default positions start at 0 regardless of how
                # many entries the cache already holds. Confirm callers always
                # pass explicit positions when that matters.
                cache_position = torch.arange(
                    decoder_hidden_states.shape[1],
                    device=decoder_hidden_states.device,
                )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        decoder_position_embeddings = self.decoder.model.rotary_emb(
            decoder_hidden_states, position_ids
        )

        if hasattr(self.decoder.model, "_update_causal_mask"):
            # Let the HF backbone build/convert the mask for its attention backend.
            attention_mask = self.decoder.model._update_causal_mask(
                attention_mask=attention_mask,
                input_tensor=decoder_hidden_states,
                cache_position=cache_position,
                past_key_values=past_key_values,
                output_attentions=False,
            )

        # ---------------- Decoder layers ----------------
        for decoder_layer in self.decoder.model.layers:
            layer_idx = decoder_layer.self_attn.layer_idx
            if (
                self.tie_encoder_decoder_weights
                and layer_idx not in self.decoder_layer_idxs
            ):
                # Shared module: only the selected top layers act as decoder.
                continue

            # Cache length before this layer appends the decoder tokens' K/V.
            if past_key_values is not None and len(past_key_values) > layer_idx:
                prev_cache_len = past_key_values[layer_idx][0].shape[-2]
            else:
                prev_cache_len = 0

            # NOTE(review): the layer receives ``position_ids[0]`` as
            # cache_position, not the ``cache_position`` used for the mask.
            # These differ only if the caller passes both explicitly.
            decoder_hidden_states = decoder_layer(
                hidden_states=decoder_hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=position_ids[0],
                position_embeddings=decoder_position_embeddings,
                **flash_attn_kwargs,
            )[0]

            if past_key_values is not None:
                # Restore this layer's cache to its pre-layer length.
                past_key_values.key_cache[layer_idx] = past_key_values.key_cache[
                    layer_idx
                ][..., :prev_cache_len, :]
                past_key_values.value_cache[layer_idx] = past_key_values.value_cache[
                    layer_idx
                ][..., :prev_cache_len, :]

        decoder_hidden_states = self.decoder.model.norm(decoder_hidden_states)
        logits = self.decoder.lm_head(decoder_hidden_states)
        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=past_key_values,
        )
| |
|