# LongCat-Flash-Chat / modeling_longcat_flash.py
# LongCat0830's picture
# Upload 2 files
# a39cdc2 verified
# -*- coding: utf-8 -*-
# Copyright (c) 2025 Meituan
# This code is licensed under the MIT License, for details, see the ./LICENSE file.
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.generic import check_model_inputs
from .configuration_longcat_flash import LongcatFlashConfig
@use_kernel_forward_from_hub("RMSNorm")
class LongcatFlashRMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned scale.

    The statistics are computed in float32; the normalized values are cast back
    to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """
        LongcatFlashRMSNorm is equivalent to T5LayerNorm

        Args:
            hidden_size: Size of the learned scale vector (initialized to ones).
            eps: Constant added to the mean square before the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Return ``weight * hidden_states / sqrt(mean(hidden_states**2) + eps)``."""
        original_dtype = hidden_states.dtype
        # Upcast so the mean square is not accumulated in reduced precision.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = upcast * inverse_rms
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        """Describe the scale shape and epsilon for ``repr(module)``."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class LongcatFlashRotaryEmbedding(nn.Module):
    """Builds the cosine/sine tables consumed by the rotary position embedding."""

    def __init__(self, config: LongcatFlashConfig, device=None):
        """
        Args:
            config: Model configuration; ``rope_scaling`` (if it is a dict) selects the RoPE variant,
                otherwise the ``"default"`` variant is used.
            device: Forwarded to the RoPE init function that creates ``inv_freq``.
        """
        super().__init__()
        # BC: "rope_type" was originally "type"
        rope_scaling = getattr(config, "rope_scaling", None)
        if isinstance(rope_scaling, dict):
            self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
        else:
            self.rope_type = "default"

        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config

        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, attention_scaling = self.rope_init_fn(self.config, device)
        self.attention_scaling = attention_scaling
        # Non-persistent: the buffer is not written to the state dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` for ``position_ids`` in the dtype of ``x``.

        ``x`` is only used for its device and dtype. ``position_ids`` is indexed as a
        2-D tensor whose first dimension is the batch.
        """
        batch = position_ids.shape[0]
        inv_freq = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        positions = position_ids[:, None, :].float()

        # "mps" (and non-string device types) fall back to the CPU autocast context.
        if isinstance(x.device.type, str) and x.device.type != "mps":
            autocast_device = x.device.type
        else:
            autocast_device = "cpu"

        with torch.autocast(device_type=autocast_device, enabled=False):  # Force float32
            angles = torch.matmul(inv_freq.float(), positions.float()).transpose(1, 2)
            # Duplicate the angles so the table spans both halves of the rotated dimension.
            table = torch.cat((angles, angles), dim=-1)
            cos = table.cos() * self.attention_scaling
            sin = table.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class LongcatFlashMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""

    def __init__(self, config, hidden_size=None, intermediate_size=None):
        """
        Args:
            config: Supplies ``hidden_act`` and the default sizes.
            hidden_size: Input/output width; defaults to ``config.hidden_size``.
            intermediate_size: Inner width; defaults to ``config.ffn_hidden_size``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size if hidden_size is not None else config.hidden_size
        self.intermediate_size = (
            intermediate_size if intermediate_size is not None else config.ffn_hidden_size
        )

        outer, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(outer, inner, bias=False)
        self.up_proj = nn.Linear(outer, inner, bias=False)
        self.down_proj = nn.Linear(inner, outer, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Apply the gated projection to ``x`` and project back to ``hidden_size``."""
        gate = self.act_fn(self.gate_proj(x))
        candidate = self.up_proj(x)
        return self.down_proj(gate * candidate)
class LongcatFlashTopkRouter(nn.Module):
    """Softmax router that picks ``top_k`` experts per token.

    The number of routing targets is ``config.n_routed_experts`` plus
    ``config.zero_expert_num`` when the latter is not ``None``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_topk
        self.n_routed_experts = (
            config.n_routed_experts
            if config.zero_expert_num is None
            else config.n_routed_experts + config.zero_expert_num
        )
        self.routed_scaling_factor = config.routed_scaling_factor
        self.norm_topk_prob = config.norm_topk_prob
        self.router_bias = config.router_bias
        self.classifier = nn.Linear(config.hidden_size, self.n_routed_experts, bias=self.router_bias)
        self.register_buffer("e_score_correction_bias", torch.zeros((self.n_routed_experts)))

    @torch.no_grad()
    def get_topk_indices(self, scores):
        """Return the indices of the ``top_k`` experts per token.

        ``e_score_correction_bias`` is added for the selection only; the returned
        indices are later used to gather the uncorrected scores.
        """
        scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)
        topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1]
        return topk_indices

    def forward(self, hidden_states):
        """Route every token to ``top_k`` experts.

        Args:
            hidden_states: Tensor whose last dimension is ``config.hidden_size``; it is
                flattened to ``(num_tokens, hidden_size)``.

        Returns:
            ``(topk_indices, topk_weights)``, both of shape ``(num_tokens, top_k)``. The weights
            are float32 softmax scores, optionally renormalized over the selected experts, and
            multiplied by ``routed_scaling_factor``.
        """
        hidden_states = hidden_states.view(-1, self.config.hidden_size)
        # The classifier is created with `bias=router_bias`, so the bias must take part in
        # the logits when it exists; previously it was silently dropped.
        classifier_bias = self.classifier.bias
        router_logits = F.linear(
            hidden_states.type(torch.float32),
            self.classifier.weight.type(torch.float32),
            None if classifier_bias is None else classifier_bias.type(torch.float32),
        )
        scores = router_logits.softmax(dim=-1)
        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)
        if self.norm_topk_prob:
            # Tiny constant guards against a zero denominator.
            denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20
            topk_weights = topk_weights / denominator
        topk_weights = topk_weights * self.routed_scaling_factor
        return topk_indices, topk_weights
class LongcatFlashMoE(nn.Module):
    """
    Mixture-of-experts block: a top-k router plus a list of expert MLPs, optionally
    extended with "zero experts" that have no parameters.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        expert_modules = [
            LongcatFlashMLP(config, intermediate_size=config.expert_ffn_hidden_size)
            for _ in range(config.n_routed_experts)
        ]
        self.experts = nn.ModuleList(expert_modules)
        self.router = LongcatFlashTopkRouter(config)
        self.zero_expert_num = config.zero_expert_num
        self.zero_expert_type = config.zero_expert_type

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        """Combine expert outputs for flattened tokens.

        Args:
            hidden_states: ``(num_tokens, hidden)`` token features.
            topk_indices: ``(num_tokens, top_k)`` expert index chosen for each slot.
            topk_weights: ``(num_tokens, top_k)`` weight of each chosen expert.

        Returns:
            Weighted sum of the selected experts' outputs, in the dtype of ``hidden_states``.

        Raises:
            ValueError: If tokens are routed to a zero expert and ``zero_expert_type``
                is not ``"identity"``.
        """
        num_real_experts = len(self.experts)
        if self.zero_expert_num is None:
            num_total_experts = num_real_experts
        else:
            num_total_experts = num_real_experts + self.zero_expert_num

        # Accumulate in the routing-weight dtype, cast back at the end.
        combined = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)

        # (tokens, top_k, experts) -> (experts, tokens, top_k)
        routing = torch.nn.functional.one_hot(topk_indices, num_classes=num_total_experts).permute(2, 0, 1)

        for expert_idx in range(num_total_experts):
            token_ids, slot_ids = torch.where(routing[expert_idx])
            if token_ids.numel() == 0:
                # No token selected this expert.
                continue

            gate = topk_weights[token_ids, slot_ids]
            routed_tokens = hidden_states[token_ids]

            if self.zero_expert_num is None or expert_idx < num_real_experts:
                expert_result = self.experts[expert_idx](routed_tokens)
            elif self.zero_expert_type == "identity":
                # Zero expert: passes its input through unchanged.
                expert_result = routed_tokens
            else:
                raise ValueError("Unknown condition")

            combined.index_add_(0, token_ids, expert_result * gate.unsqueeze(-1))

        return combined.type(hidden_states.dtype)

    def forward(self, hidden_states):
        """Route ``hidden_states`` through the experts and return a tensor of the same shape."""
        input_shape = hidden_states.shape
        topk_indices, topk_weights = self.router(hidden_states)
        flat_tokens = hidden_states.view(-1, hidden_states.shape[-1])
        mixed = self.moe(flat_tokens, topk_indices, topk_weights)
        return mixed.view(*input_shape)
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    The last dimension is cut at ``size // 2``; the result is the negated second
    part followed by the first part.
    """
    cut = x.shape[-1] // 2
    front, back = x[..., :cut], x[..., cut:]
    return torch.cat((-back, front), dim=-1)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head ``n_rep`` times along the head dimension.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With ``n_rep == 1`` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain matmul/softmax attention.

    Args:
        module: Attention module; ``num_key_value_groups`` and ``training`` are read from it.
        query: Query tensor; multiplied against the transposed (dims 2 and 3) keys.
        key: Key tensor, repeated ``module.num_key_value_groups`` times along the head dim.
        value: Value tensor, repeated the same way as ``key``.
        attention_mask: Optional additive mask; cropped on its last dim to the key length.
        scaling: Factor applied to the raw attention scores.
        dropout: Dropout probability applied to the attention weights in training mode.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` has dims 1 and 2 swapped
        relative to the attention result and is contiguous.
    """
    groups = module.num_key_value_groups
    key_states = repeat_kv(key, groups)
    value_states = repeat_kv(value, groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        kv_len = key_states.shape[-2]
        scores = scores + attention_mask[:, :, :, :kv_len]

    # Softmax in float32, then back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(attn_weights, value_states)
    attn_output = context.transpose(1, 2).contiguous()
    return attn_output, attn_weights
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1, use_mla=False):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against q and k.
            For q and k of shape [batch_size, heads, seq_len, head_dim] use 1; for
            [batch_size, seq_len, heads, head_dim] use 2.
        use_mla (`bool`, *optional*, defaults to `False`):
            When set, q and k must be 4-D and their last dimension is first rearranged from interleaved
            pairs (x0, x1, x2, x3, ...) to (x0, x2, ..., x1, x3, ...) before the rotation is applied.

    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """

    def _deinterleave(t):
        # Group the last dim into pairs, swap the pair axis to the front, and flatten again.
        b, h, s, d = t.shape
        return t.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d)

    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)

    if use_mla:
        q = _deinterleave(q)
        k = _deinterleave(k)

    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed
class LongcatFlashMLA(nn.Module):
    """Modified from Deepseek MLA

    Attention with a low-rank compressed key/value projection and a separate rotary
    sub-dimension for queries and keys. Compared with a plain implementation this class
    optionally rescales the low-rank query / key-value activations
    (``config.mla_scale_q_lora`` / ``config.mla_scale_kv_lora``).
    """

    def __init__(self, config: LongcatFlashConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.kv_lora_rank = config.kv_lora_rank
        self.v_head_dim = config.v_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.qk_head_dim = config.qk_head_dim
        self.is_causal = True

        if self.q_lora_rank is None:
            # Full-rank query projection.
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False)
        else:
            # Low-rank query projection: down-project, normalize, up-project.
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias)
            self.q_a_layernorm = LongcatFlashRMSNorm(config.q_lora_rank)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False)

        # Produces the compressed KV latent together with the (head-shared) rotary key part.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_a_layernorm = LongcatFlashRMSNorm(self.kv_lora_rank)
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

        # Always define both scale attributes: `forward` tests them against None, and a
        # missing attribute on an nn.Module raises AttributeError instead of yielding None.
        self.mla_scale_q_lora = None
        self.mla_scale_kv_lora = None
        # The query scale is derived from the LoRA rank, so it only exists with a low-rank query.
        if config.mla_scale_q_lora and self.q_lora_rank is not None:
            self.mla_scale_q_lora = (config.hidden_size / self.q_lora_rank) ** 0.5
        if config.mla_scale_kv_lora:
            self.mla_scale_kv_lora = (config.hidden_size / self.kv_lora_rank) ** 0.5

        self.scaling = self.qk_head_dim ** (-0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: ``(batch, seq, hidden)`` input.
            position_embeddings: ``(cos, sin)`` tables for the rotary part of queries and keys.
            attention_mask: Mask forwarded unchanged to the attention implementation.
            past_key_value: Optional cache; when given, the new keys/values are merged into it
                under ``self.layer_idx``.
            cache_position: Forwarded to the cache update.
            **kwargs: Forwarded to the attention implementation.

        Returns:
            ``(attn_output, attn_weights)``; ``attn_output`` has the hidden size of the input.
        """
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        # Use whichever query projection __init__ created.
        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states)))
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # apply q_lora scaling
        if self.mla_scale_q_lora is not None:
            q_pass = q_pass * self.mla_scale_q_lora
            q_rot = q_rot * self.mla_scale_q_lora

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        k_pass = self.kv_a_layernorm(k_pass)

        # apply kv_lora scaling
        if self.mla_scale_kv_lora is not None:
            k_pass = k_pass * self.mla_scale_kv_lora

        k_pass = self.kv_b_proj(k_pass).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key part is shared by all heads: give it a singleton head axis.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin, use_mla=True)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # For flash_attention_2 the values are zero-padded up to the query/key head dim and
        # the padding is cut off again after attention.
        pad_values = self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim
        if pad_values:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        if pad_values:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
def create_attention_block(class_name, *args, **kwargs):
    """Instantiate the attention class registered under ``class_name``.

    Remaining positional and keyword arguments are passed to the class constructor.

    Raises:
        ValueError: If ``class_name`` is not a registered name.
    """
    registry = {"MLA": LongcatFlashMLA}
    if class_name not in registry:
        raise ValueError(f"No class found for name: {class_name}")
    return registry[class_name](*args, **kwargs)
class LongcatFlashDecoderLayer(GradientCheckpointingLayer):
    """Decoder layer made of two attention + dense-MLP sub-layers and one MoE shortcut.

    The MoE block is evaluated on the normalized post-attention activations of the first
    sub-layer, and its output is added to the hidden states after the second sub-layer.
    """

    def __init__(self, config: LongcatFlashConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.mlp = LongcatFlashMoE(config)

        self_attn = []
        mlps = []
        input_layernorm = []
        post_attention_layernorm = []
        for i in range(2):
            # Each sub-layer gets its own attention index: 2 * layer_idx and 2 * layer_idx + 1.
            self_attn.append(
                create_attention_block(config.attention_method, config=config, layer_idx=layer_idx * 2 + i)
            )
            mlps.append(LongcatFlashMLP(config))
            input_layernorm.append(LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps))
            post_attention_layernorm.append(LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps))

        self.self_attn = nn.ModuleList(self_attn)
        self.mlps = nn.ModuleList(mlps)
        self.input_layernorm = nn.ModuleList(input_layernorm)
        self.post_attention_layernorm = nn.ModuleList(post_attention_layernorm)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> torch.Tensor:
        """Apply both sub-layers and the MoE shortcut.

        All arguments other than ``hidden_states`` are forwarded to the attention modules.

        Returns:
            The updated hidden states as a single tensor (attention weights are discarded).
        """
        for i in range(2):
            # Attention with pre-normalization and residual connection.
            residual = hidden_states
            hidden_states = self.input_layernorm[i](hidden_states)
            hidden_states, _ = self.self_attn[i](
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )
            hidden_states = residual + hidden_states

            # Dense MLP with pre-normalization and residual connection.
            residual = hidden_states
            hidden_states = self.post_attention_layernorm[i](hidden_states)
            if i == 0:
                # The MoE reads the same normalized input as the first dense MLP.
                shortcut_mlp_output = self.mlp(hidden_states)  # shortcut output (MoE output)
            hidden_states = self.mlps[i](hidden_states)
            hidden_states = residual + hidden_states

            if i == 1:
                # Merge the MoE shortcut once the second sub-layer is done.
                hidden_states = hidden_states + shortcut_mlp_output

        return hidden_states
@auto_docstring
class LongcatFlashPreTrainedModel(PreTrainedModel):
    # Shared base class for the LongcatFlash models defined in this file. It only declares
    # class-level settings; how each one is interpreted is decided by `PreTrainedModel`
    # in transformers, not by code in this file.
    config: LongcatFlashConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LongcatFlashDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Attention backends / compilation features this model family declares support for.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Maps an output name to the module class associated with it; presumably consumed by the
    # transformers output-recording machinery — confirm against the installed version.
    _can_record_outputs = {
        "hidden_states": LongcatFlashDecoderLayer,
        "attentions": LongcatFlashMLA,
    }
@auto_docstring
class LongcatFlashModel(LongcatFlashPreTrainedModel):
    # Token embedding, a stack of LongcatFlash decoder layers, and a final RMS norm.
    _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]

    def __init__(self, config: LongcatFlashConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        decoder_layers = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_layers.append(LongcatFlashDecoderLayer(config, layer_idx))
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = LongcatFlashRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = LongcatFlashRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of the two input forms must be supplied.
        ids_given = input_ids is not None
        embeds_given = inputs_embeds is not None
        if ids_given == embeds_given:
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if not embeds_given:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        # Create a cache on demand when caching is requested.
        if past_key_values is None and use_cache:
            past_key_values = DynamicCache()

        if cache_position is None:
            # New tokens are placed right after whatever the cache already holds.
            start = 0 if past_key_values is None else past_key_values.get_seq_length()
            stop = start + inputs_embeds.shape[1]
            cache_position: torch.Tensor = torch.arange(start, stop, device=inputs_embeds.device)

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # The rotary tables are computed once and shared by every layer.
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        active_layers = self.layers[: self.config.num_hidden_layers]
        for decoder_layer in active_layers:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        normalized = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=normalized,
            past_key_values=past_key_values,
        )
@auto_docstring
class LongcatFlashForCausalLM(LongcatFlashPreTrainedModel, GenerationMixin):
    # LongcatFlashModel backbone followed by a bias-free linear head over the vocabulary.
    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}
    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
    _keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]

    def __init__(self, config):
        super().__init__(config)
        self.model = LongcatFlashModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def set_decoder(self, decoder):
        # Replace the backbone model.
        self.model = decoder

    def get_decoder(self):
        # Return the backbone model.
        return self.model

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
    ) -> CausalLMOutputWithPast:
        r"""
        Example:

        ```python
        >>> from transformers import AutoTokenizer, LongcatFlashForCausalLM

        >>> model = LongcatFlashForCausalLM.from_pretrained("meta-longcat_flash/LongcatFlash-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-longcat_flash/LongcatFlash-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        backbone_out: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = backbone_out.last_hidden_state

        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        if isinstance(logits_to_keep, int):
            # An int keeps the last `logits_to_keep` positions; 0 yields slice(0, None), i.e. all of them.
            kept_positions = slice(-logits_to_keep, None)
        else:
            kept_positions = logits_to_keep
        logits = self.lm_head(hidden_states[:, kept_positions, :])

        if labels is None:
            loss = None
        else:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
# Names exported by `from <this module> import *`.
__all__ = ["LongcatFlashPreTrainedModel", "LongcatFlashModel", "LongcatFlashForCausalLM"]