# -*- coding: utf-8 -*-
# Copyright (c) 2026 Meituan
# This code is licensed under the MIT License; for details, see the ./LICENSE file.
import os
from dataclasses import dataclass
from tqdm import tqdm
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput, GenerateNonBeamOutput
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.longcat_flash.modeling_longcat_flash import LongcatFlashForCausalLM
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from .configuration_longcat_next import LongcatNextConfig
from .modeling_longcat_ngram import LongcatFlashNgramModel, NgramCache
from .modular_longcat_next import CasualDepthTransformerHead
from .modular_longcat_next_audio import LongcatNextAudioTokenizer
from .modular_longcat_next_visual import LongcatNextVisualTokenizer
from .cosy24k_vocoder import Cosy24kVocoder
from .image_refiner import ImageRefinerContainer
from .refiner_modules import FlowMatchEulerDiscreteScheduler
logger = logging.get_logger(__name__)
@dataclass
class LongcatNextForCausalLMOutputWithPast(CausalLMOutputWithPast):
visual_loss: Optional[torch.FloatTensor] = None
visual_logits: Optional[torch.FloatTensor] = None
visual_ids: Optional[torch.LongTensor] = None
audio_loss: Optional[torch.FloatTensor] = None
audio_logits: Optional[torch.FloatTensor] = None
audio_ids: Optional[torch.LongTensor] = None
@dataclass
class LongcatNextForCausalLMGenerateDecoderOnlyOutput(GenerateDecoderOnlyOutput):
visual_ids: Optional[torch.LongTensor] = None
audio_ids: Optional[torch.LongTensor] = None
audio_text_ids: Optional[torch.LongTensor] = None
@dataclass
class LongcatNextForCausalLMGenerateEncoderDecoderOutput(GenerateEncoderDecoderOutput):
visual_ids: Optional[torch.LongTensor] = None
audio_ids: Optional[torch.LongTensor] = None
audio_text_ids: Optional[torch.LongTensor] = None
@dataclass
class LongcatNextForCausalLMGenerationStatus:
mode: str = "text"
current_image_token_num: int = -1
audio_parallel_decoding: bool = False
is_audio_text_end: bool = False
is_audio_start: bool = False
    last_step_mode: Optional[str] = None
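    # `mode` selects which head decodes next ("text" / "visual" / "audio");
    # `current_image_token_num` counts visual tokens emitted so far (-1 outside visual
    # mode), and `last_step_mode` remembers the mode of the previous decoding step.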
def __init__(self, visual_generation_config, audio_generation_config):
self.visual_generation_config = visual_generation_config
self.h = self.visual_generation_config.custom_params["token_h"]
self.w = self.visual_generation_config.custom_params["token_w"]
self.anyres_prefix = self.visual_generation_config.custom_params["anyres_prefix"].format(h=self.h, w=self.w)
self.audio_generation_config = audio_generation_config
self.audio_parallel_decoding = audio_generation_config.audio_parallel_decoding
def switch_to(self, modal):
assert modal in ["text", "visual", "audio"]
self.mode = modal
self.current_image_token_num = 0 if modal == "visual" else -1
self.is_audio_text_end = False
self.is_audio_start = False
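    # Image layout: `h` rows of `w` visual tokens, each row terminated by one
    # image-newline token, so a full image spans h * (w + 1) decoding steps.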
@property
def is_img_newline(self):
return ((self.current_image_token_num + 1) % (self.w + 1)) == 0 and not self.is_img_end
@property
def is_img_end(self):
        return (self.current_image_token_num + 1) == self.h * (self.w + 1)
class LongcatNextModel(LongcatFlashNgramModel):
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
config_class = LongcatNextConfig
def __init__(self, config):
super().__init__(config)
self.visual_tokenizer = LongcatNextVisualTokenizer(config)
self.audio_tokenizer = LongcatNextAudioTokenizer(config)
self._init_multimodal_constants(config)
self.post_init()
def _init_multimodal_constants(self, config):
name2id_dict = {
"image_newline_token_id": self.config.visual_config.image_newline_token_id,
"image_end_token_id": self.config.visual_config.image_end_token_id,
"image_pad_token_id": self.config.visual_config.image_pad_token_id,
"audiotext_start_token_id": config.audio_config.audiotext_start_token_id,
"audiotext_pad_token_id": self.config.audio_config.audiotext_pad_token_id,
"audiogen_end_token_id": config.audio_config.audiogen_end_token_id,
"audio_pad_token_id": self.config.audio_config.audio_pad_token_id,
}
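        # Register each special id as a non-persistent buffer so it moves with the module across devices.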
for k, v in name2id_dict.items():
self.register_buffer(k, torch.tensor([v], dtype=torch.long), persistent=False)
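        # offset_vals[i] marks where codebook level i starts inside the shared embedding
        # table: the cumulative sum of the base offset and the preceding codebook sizes.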
visual_offset_list = [config.visual_offset] + config.visual_config.vq_config.codebook_sizes[:-1]
visual_offset_vals = torch.cumsum(torch.tensor(visual_offset_list, dtype=torch.long), dim=0)
self.register_buffer("visual_offset_vals", visual_offset_vals, persistent=False)
audio_offset_list = [config.audio_offset] + config.audio_config.vq_config.codebook_sizes[:-1]
audio_offset_vals = torch.cumsum(torch.tensor(audio_offset_list, dtype=torch.long), dim=0)
self.register_buffer("audio_offset_vals", audio_offset_vals, persistent=False)
print(f"{self.visual_offset_vals=}")
print(f"{self.audio_offset_vals=}")
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
visual_inputs=None,
visual_ids=None,
audio_inputs=None,
audio_ids=None,
audio_text_ids=None,
multimodal_generation_status=None,
**kwargs
) -> BaseModelOutputWithPast:
if input_ids is None:
raise ValueError("You must specify input_ids")
# Extract N-gram context if available
ngram_context = None
if isinstance(past_key_values, NgramCache) and past_key_values.ngram_context is not None:
ngram_context = past_key_values.ngram_context
        # assert input_ids.size(0) == 1  # generation normally assumes bs=1; with bs=2, row 1 carries the unconditional branch for CFG image generation
special_visual_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask = self.get_placeholder_mask(input_ids[:1]) # seq-dim
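        # Build inputs_embeds manually: zero the multimodal placeholder slots in
        # input_ids, then write (visual) or add (audio / audio-text) the modality embeddings below.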
if inputs_embeds is None:
input_ids[:, special_visual_mask | special_audio_mask | special_audio_text_pad_mask | special_audio_text_start_mask] = 0
filled_text_pad_mask = torch.ones_like(special_audio_mask)
audio_text_position_mask = (special_audio_text_pad_mask | special_audio_text_start_mask | special_audio_mask)
if audio_text_ids is not None and audio_text_ids.size(1) > 0 and audio_text_position_mask.sum() > 0:
filled_text = audio_text_ids[:, -audio_text_position_mask.sum():]
filled_text_pad_mask = (filled_text==self.config.audio_config.audiotext_pad_token_id)[0]
input_ids[:, audio_text_position_mask] = filled_text
input_ids[input_ids == self.config.audio_config.audiotext_pad_token_id] = 0
inputs_embeds = self.ngram_embeddings(input_ids, ngram_context=ngram_context)
inputs_embeds[:, (special_visual_mask | (special_audio_mask & filled_text_pad_mask))] = 0
if special_audio_text_start_mask.sum() > 0:
audio_text_start_embedding = self.embed_tokens(self.audiotext_start_token_id)
if multimodal_generation_status.last_step_mode is None: # prefill
inputs_embeds[:1, special_audio_text_start_mask] += audio_text_start_embedding
else:
inputs_embeds[:, special_audio_text_start_mask] += audio_text_start_embedding
if visual_inputs is not None:
visual_ids = self.get_visual_ids(**visual_inputs) # [<bs=1>*seq, lev]
if visual_ids is not None and special_visual_mask.sum() > 0:
visual_embeddings = self.get_visual_embeddings(visual_ids[-special_visual_mask.sum():]) # -> [seq, dim]
if multimodal_generation_status.last_step_mode is None: # prefill
inputs_embeds[:1, special_visual_mask] = visual_embeddings.to(inputs_embeds.device)
else:
inputs_embeds[:, special_visual_mask] = visual_embeddings.to(inputs_embeds.device)
if audio_inputs is not None:
audio_ids = self.get_audio_ids(**audio_inputs) # -> [<bs=1>*seq, lev]
if audio_ids is not None and special_audio_mask.sum() > 0:
audio_embeddings = self.get_audio_embeddings(audio_ids[-special_audio_mask.sum():]) # -> [seq, dim]
if multimodal_generation_status.last_step_mode is None: # prefill
inputs_embeds[:1, special_audio_mask] += audio_embeddings.to(inputs_embeds.device)
else:
inputs_embeds[:, special_audio_mask] += audio_embeddings.to(inputs_embeds.device)
# Initialize NgramCache if needed
if use_cache and past_key_values is None:
past_key_values = NgramCache(config=self.config)
# Update N-gram context
if use_cache and isinstance(past_key_values, NgramCache):
past_key_values.update_ngram_context(input_ids)
return super().forward(
input_ids=None,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
use_cache=use_cache,
**kwargs
)
def get_visual_ids(self, pixel_values, visual_grid_thw, offset=True):
visual_ids = self.visual_tokenizer.encode(pixel_values, visual_grid_thw)
if offset:
visual_ids += self.visual_offset_vals.to(visual_ids.device)
return visual_ids
def get_audio_ids(self, audio, encoder_length, bridge_length, offset=True):
audio_ids = self.audio_tokenizer.encode(audio, encoder_length, bridge_length)
if offset:
audio_ids += self.audio_offset_vals.to(audio_ids.device)
return audio_ids
@torch.no_grad()
def decode_visual_ids_and_save(
self,
visual_ids,
save_prefix,
token_h,
token_w,
**kwargs,
):
visual_ids -= self.visual_offset_vals.to(visual_ids.device)
if not (save_prefix.startswith("./") or save_prefix.startswith("/")):
save_prefix = f"./{save_prefix}"
os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
        return self.visual_tokenizer.lazy_decode_and_save(visual_ids, token_h, token_w, f"{save_prefix}_0.png")
@torch.no_grad()
def decode_audio_ids_and_save(
self,
audio_ids,
save_prefix,
sampling_rate,
wave_concat_overlap,
**kwargs,
):
audio_ids -= self.audio_offset_vals.to(audio_ids.device)
if not (save_prefix.startswith("./") or save_prefix.startswith("/")):
save_prefix = f"./{save_prefix}"
os.makedirs(os.path.dirname(save_prefix), exist_ok=True)
        save_path = f"{save_prefix}_0.wav"
self.audio_tokenizer.lazy_decode_and_save(audio_ids, sampling_rate, wave_concat_overlap, save_path)
return [save_path]
def get_visual_embeddings(self, visual_ids):
visual_embeddings = self.embed_tokens(visual_ids).sum(dim=1) # [seq, lev] -> [seq, lev, dim] -> [seq, dim]
visual_embeddings = self.visual_tokenizer.visual_embedding_layer(visual_embeddings)
return visual_embeddings
def get_audio_embeddings(self, audio_ids):
audio_embeddings = self.embed_tokens(audio_ids).sum(dim=1)
return audio_embeddings
def get_placeholder_mask(self, input_ids: torch.LongTensor):
special_image_mask = (input_ids == self.config.visual_config.image_pad_token_id).squeeze(0)
special_audio_mask = (input_ids == self.config.audio_config.audio_pad_token_id).squeeze(0)
special_audio_text_start_mask = (input_ids == self.config.audio_config.audiotext_start_token_id).squeeze(0)
special_audio_text_pad_mask = (input_ids == self.config.audio_config.audiotext_pad_token_id).squeeze(0)
return special_image_mask, special_audio_mask, special_audio_text_start_mask, special_audio_text_pad_mask
class LongcatNextForCausalLM(LongcatFlashForCausalLM):
_keys_to_ignore_on_load_unexpected = [r"model\.mtp.*"]
_no_split_modules = [
"LongcatFlashDecoderLayer",
"CasualDepthTransformerHead",
]
config_class = LongcatNextConfig
def __init__(self, config):
super().__init__(config)
self.config = config
self.model = LongcatNextModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.text_vocab_plus_multimodal_special_token_size, bias=False)
self.visual_head = CasualDepthTransformerHead(
hidden_size=config.hidden_size,
codebook_sizes=config.visual_config.vq_config.codebook_sizes,
transformer_layer_num=config.visual_config.image_head_transformer_layers,
transformer_dim=config.visual_config.image_head_transformer_dims,
transformer_ffn_scale=config.visual_config.image_head_transformer_ffn_scale,
)
self.audio_head = CasualDepthTransformerHead(
hidden_size=config.hidden_size,
codebook_sizes=config.audio_config.vq_config.codebook_sizes,
transformer_layer_num=config.audio_config.audio_head_transformer_layers,
transformer_dim=config.audio_config.audio_head_transformer_dims,
transformer_ffn_scale=config.audio_config.audio_head_transformer_ffn_scale,
)
self.post_init()
@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
visual_inputs=None,
visual_ids=None,
audio_inputs=None,
audio_ids=None,
audio_text_ids=None,
multimodal_generation_status: LongcatNextForCausalLMGenerationStatus = None,
visual_generation_config: GenerationConfig = None,
audio_generation_config: GenerationConfig = None,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
r"""
visual_inputs (`BatchFeature`, *optional*):
Visual inputs returned by the processor, containing pixel values and grid metadata for image encoding.
visual_ids (`torch.LongTensor` of shape `(num_visual_tokens, num_codebooks)`, *optional*):
Quantized visual token ids from the visual tokenizer, used to build visual embeddings during generation.
audio_inputs (`BatchFeature`, *optional*):
Audio inputs returned by the processor, containing mel-spectrogram features and length metadata.
audio_ids (`torch.LongTensor` of shape `(num_audio_tokens, num_codebooks)`, *optional*):
Quantized audio token ids from the audio tokenizer, used to build audio embeddings during generation.
audio_text_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Token ids for the audio text transcript generated alongside audio tokens.
multimodal_generation_status (`LongcatNextForCausalLMGenerationStatus`, *optional*):
Stateful object tracking the current multimodal generation mode (text / visual / audio) and
associated counters used to route logits to the correct head during auto-regressive decoding.
visual_generation_config (`GenerationConfig`, *optional*):
Generation configuration for the visual head, controlling sampling parameters such as
`temperature`, `top_k`, `top_p`, and custom parameters like `cfg_scale` and `anyres_config`.
audio_generation_config (`GenerationConfig`, *optional*):
Generation configuration for the audio head, controlling sampling parameters such as
`temperature`, `top_k`, `top_p`, `repetition_penalty`, and `audio_parallel_decoding`.
"""
if multimodal_generation_status.mode == "visual" and visual_generation_config.custom_params["cfg_scale"] != 1.0 and input_ids.size(0) == 1:
input_ids = input_ids.repeat((2, 1))
outputs: BaseModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
visual_inputs=visual_inputs,
visual_ids=visual_ids,
audio_inputs=audio_inputs,
audio_ids=audio_ids,
audio_text_ids=audio_text_ids,
multimodal_generation_status=multimodal_generation_status,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
slice_hidden_states = hidden_states[:, slice_indices, :]
loss, logits = None, None
if multimodal_generation_status.mode == "visual" and \
(not multimodal_generation_status.is_img_newline) and (not multimodal_generation_status.is_img_end):
visual_ids = self.get_multimodal_logits_and_ids(
self.visual_head,
visual_ids,
slice_hidden_states,
self.model.embed_tokens,
self.config.visual_config.vq_config.codebook_sizes,
self.model.visual_offset_vals,
visual_generation_config,
)
else:
logits = self.lm_head(slice_hidden_states)
if multimodal_generation_status.mode == "audio" and multimodal_generation_status.is_audio_start:
audio_ids = self.get_multimodal_logits_and_ids(
self.audio_head,
audio_ids,
slice_hidden_states,
self.model.embed_tokens,
self.config.audio_config.vq_config.codebook_sizes,
self.model.audio_offset_vals,
audio_generation_config,
)
return LongcatNextForCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
visual_ids=visual_ids,
audio_ids=audio_ids,
)
def get_multimodal_logits_and_ids(
self,
head_model,
multimodal_ids,
hidden_states,
multimodal_embedding_layer,
codebook_sizes,
offset_vals,
multimodal_generation_config,
):
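        # Depth-transformer decoding: sample one token per codebook level, feeding the
        # levels sampled so far back into the head before predicting the next level.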
next_token_ids = torch.zeros(hidden_states.size(0), len(codebook_sizes), dtype=torch.long, device=hidden_states.device)
multimodal_embedding_layer = multimodal_embedding_layer.to(hidden_states.device)
for level, _ in enumerate(codebook_sizes):
            logits = head_model(hidden_states, next_token_ids, multimodal_embedding_layer, level)  # logits for codebook level `level`
            next_token_id = self.inner_sample(logits, multimodal_ids[None, :, level] - offset_vals[level], multimodal_generation_config)  # (bs, 1)
next_token_id += offset_vals[level]
next_token_ids[:, level] = next_token_id
return next_token_ids[:1]
def inner_sample(
self,
next_token_logits: torch.Tensor,
multimodal_ids: torch.LongTensor,
generation_config: GenerationConfig,
) -> torch.Tensor:
logits_processor = self._get_logits_processor(generation_config)
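        # Classifier-free guidance: logits = uncond + cfg_scale * (cond - uncond).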
if "cfg_scale" in generation_config.custom_params and generation_config.custom_params["cfg_scale"] != 1.0:
cond_logits, uncond_logits = next_token_logits.chunk(2, dim=0)
next_token_logits = generation_config.custom_params["cfg_scale"] * (cond_logits - uncond_logits) + uncond_logits
next_token_scores = logits_processor(multimodal_ids, next_token_logits.to(multimodal_ids.device))
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
return next_tokens
@torch.no_grad()
def generate(self, inputs=None, **kwargs):
"""Override to ensure NgramCache is used."""
if "past_key_values" not in kwargs or kwargs["past_key_values"] is None:
kwargs["past_key_values"] = NgramCache(config=self.config)
return super().generate(
inputs=inputs,
**kwargs,
)
def prepare_inputs_for_generation(
self,
input_ids,
visual_ids,
audio_ids,
audio_text_ids,
multimodal_generation_status,
generation_config,
attention_mask,
cache_position,
**kwargs,
):
extra_new_tokens = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)
if visual_ids is None:
visual_ids = torch.empty(0, len(self.config.visual_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
if audio_ids is None:
audio_ids = torch.empty(0, len(self.config.audio_config.vq_config.codebook_sizes), dtype=torch.long, device=input_ids.device)
if audio_text_ids is None:
audio_text_ids = torch.empty(input_ids.size(0), 0, dtype=torch.long, device=input_ids.device)
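        # Splice `new_ids` into `_input_ids`, growing the attention mask and cache positions to match.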
def insert_ids(new_ids, _input_ids, _attention_mask, _cache_position, position=0):
if position < 0:
parts = [_input_ids[:, :position], new_ids, _input_ids[:, position:]]
else:
parts = [_input_ids, new_ids]
_input_ids = torch.cat(parts, dim=1)
insert_len = new_ids.size(1)
_attention_mask = F.pad(_attention_mask, (0, insert_len), value=1)
insert_position = _cache_position[-1] + 1 + torch.arange(insert_len, device=_cache_position.device)
_cache_position = torch.cat([_cache_position, insert_position])
return _input_ids, _attention_mask, _cache_position
# multimodal generation status change
if cache_position[0] != 0:
multimodal_generation_status.last_step_mode = multimodal_generation_status.mode
if multimodal_generation_status.mode == "visual":
multimodal_generation_status.current_image_token_num += 1
if (input_ids[:, -1] == self.config.visual_config.image_start_token_id).all():
multimodal_generation_status.switch_to("visual")
anyres_prefix_ids = self.text_tokenizer.encode(multimodal_generation_status.anyres_prefix, return_tensors="pt")
anyres_prefix_ids = anyres_prefix_ids.to(input_ids.device)
extra_new_tokens = torch.cat([extra_new_tokens, anyres_prefix_ids], dim=1)
input_ids, attention_mask, cache_position = insert_ids(anyres_prefix_ids, input_ids, attention_mask, cache_position, position=-1)
            if input_ids.size(0) == 1:  # CFG: grow the batch from 1 to 2 (conditional + unconditional)
                input_ids = input_ids.repeat((2, 1))
                input_ids[1, :-(anyres_prefix_ids.size(-1) + 1)] = 0
                logger.info(f"change to cfg, input_ids: {input_ids}")
                attention_mask = attention_mask.repeat((2, 1))
elif (input_ids[:, -1] == self.config.audio_config.audiogen_start_token_id).all():
multimodal_generation_status.switch_to("audio")
elif (input_ids[:, -1] == self.config.audio_config.audiotext_start_token_id).all():
multimodal_generation_status.is_audio_start = True
elif ((input_ids[:, -1] == self.config.visual_config.image_end_token_id) | (input_ids[:, -1] == self.config.audio_config.audiogen_end_token_id)).all():
multimodal_generation_status.switch_to("text")
model_inputs = super().prepare_inputs_for_generation(
input_ids=input_ids,
visual_ids=visual_ids,
audio_ids=audio_ids,
audio_text_ids=audio_text_ids,
attention_mask=attention_mask,
cache_position=cache_position,
**kwargs,
)
if model_inputs["cache_position"][0] != 0:
model_inputs["visual_inputs"] = None
model_inputs["audio_inputs"] = None
return model_inputs, multimodal_generation_status, extra_new_tokens
def _sample(
self,
input_ids: torch.LongTensor,
logits_processor: LogitsProcessorList,
stopping_criteria: StoppingCriteriaList,
generation_config: GenerationConfig,
synced_gpus: bool = False,
streamer: Optional["BaseStreamer"] = None,
visual_ids=None,
audio_ids=None,
audio_text_ids=None,
**model_kwargs,
) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
r"""
Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
Parameters:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
The sequence used as a prompt for the generation.
logits_processor (`LogitsProcessorList`):
An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
used to modify the prediction scores of the language modeling head applied at each generation step.
stopping_criteria (`StoppingCriteriaList`):
An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
used to tell if the generation loop should stop.
generation_config ([`~generation.GenerationConfig`]):
The generation configuration to be used as parametrization of the decoding method.
synced_gpus (`bool`):
Whether to continue running the while loop until max_length (needed to avoid deadlocking with
`FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3).
streamer (`BaseStreamer`, *optional*):
Streamer object that will be used to stream the generated sequences. Generated tokens are passed
through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
model_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
an encoder-decoder model the kwargs should include `encoder_outputs`.
Return:
[`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
A `torch.LongTensor` containing the generated tokens (default behaviour) or a
[`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
`return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
`model.config.is_encoder_decoder=True`.
"""
# init values
pad_token_id = generation_config._pad_token_tensor
output_attentions = generation_config.output_attentions
output_hidden_states = generation_config.output_hidden_states
output_scores = generation_config.output_scores
output_logits = generation_config.output_logits
return_dict_in_generate = generation_config.return_dict_in_generate
has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
do_sample = generation_config.do_sample
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
raw_logits = () if (return_dict_in_generate and output_logits) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
# keep track of which sequences are already finished
batch_size, cur_len = input_ids.shape[:2]
this_peer_finished = False
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
model_forward = self.__call__
compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
if compile_forward:
os.environ["TOKENIZERS_PARALLELISM"] = "0"
# If we use FA2 and a static cache, we cannot compile with fullgraph
if self.config._attn_implementation == "flash_attention_2":
# only raise warning if the user passed an explicit compile-config
if generation_config.compile_config is not None and generation_config.compile_config.fullgraph:
logger.warning_once(
"When using Flash Attention 2 and a static cache, you cannot use the option `CompileConfig(fullgraph=True)` as "
"FA2 introduces graph breaks. We overrode the option with `fullgraph=False`."
)
generation_config.compile_config.fullgraph = False
model_forward = self.get_compiled_call(generation_config.compile_config)
if generation_config.prefill_chunk_size is not None:
model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
is_prefill = False
else:
is_prefill = True
visual_generation_config = GenerationConfig(**generation_config.visual_generation_config)
audio_generation_config = GenerationConfig(**generation_config.audio_generation_config)
multimodal_generation_status = LongcatNextForCausalLMGenerationStatus(visual_generation_config, audio_generation_config)
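        # iter(int, 1) never reaches its sentinel (int() always returns 0), giving an
        # endless tqdm counter that is closed manually once the loop ends.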
pbar = tqdm(iter(int, 1), desc="Generating", unit="tok")
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# prepare model inputs
model_inputs, multimodal_generation_status, extra_new_tokens = self.prepare_inputs_for_generation(
input_ids,
visual_ids,
audio_ids,
audio_text_ids,
multimodal_generation_status,
generation_config,
**model_kwargs,
)
if extra_new_tokens.size(1) > 0:
input_ids = torch.cat([input_ids[:, :-1], extra_new_tokens, input_ids[:, -1:]], dim=1)
model_kwargs["attention_mask"] = model_inputs["attention_mask"]
model_kwargs["cache_position"] = model_inputs["cache_position"]
if multimodal_generation_status.mode == "text" and multimodal_generation_status.last_step_mode == "visual":
next_tokens = generation_config._eos_token_tensor
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
if streamer is not None:
streamer.put(next_tokens.cpu())
break
visual_ids = model_inputs["visual_ids"]
audio_ids = model_inputs["audio_ids"]
audio_text_ids = model_inputs["audio_text_ids"]
            if is_prefill:
                outputs = self(
                    **model_inputs,
                    return_dict=True,
                    multimodal_generation_status=multimodal_generation_status,
                    visual_generation_config=visual_generation_config,
                    audio_generation_config=audio_generation_config,
                )
                is_prefill = False
            else:
                outputs = model_forward(
                    **model_inputs,
                    return_dict=True,
                    multimodal_generation_status=multimodal_generation_status,
                    visual_generation_config=visual_generation_config,
                    audio_generation_config=audio_generation_config,
                )
# synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
num_new_tokens=1,
)
if synced_gpus and this_peer_finished:
continue
# multimodal generation
if multimodal_generation_status.mode == "text" or \
(multimodal_generation_status.mode == "audio" and not multimodal_generation_status.is_audio_text_end):
# Copy is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device)
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores,)
if output_logits:
raw_logits += (next_token_logits,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
)
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states:
decoder_hidden_states += (
(outputs.decoder_hidden_states,)
if self.config.is_encoder_decoder
else (outputs.hidden_states,)
)
# token selection
if do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
# TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(next_token_scores, dim=-1)
# audio_text_ids done
if multimodal_generation_status.mode == "audio" and (next_tokens == self.config.audio_config.audiotext_pad_token_id).all():
multimodal_generation_status.is_audio_text_end = True
elif multimodal_generation_status.mode == "visual":
if multimodal_generation_status.is_img_end:
next_tokens = self.model.image_end_token_id.to(input_ids.device)
elif multimodal_generation_status.is_img_newline:
next_tokens = self.model.image_newline_token_id.to(input_ids.device)
else:
visual_ids = torch.cat([visual_ids, outputs.visual_ids], dim=0) # [seq, lev]
next_tokens = self.model.image_pad_token_id.to(input_ids.device)
else: # mode == "audio" and multimodal_generation_status.is_audio_text_end
next_tokens = self.model.audio_pad_token_id.to(input_ids.device)
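            # Audio mode interleaves a transcript stream (audio_text_ids) with acoustic
            # tokens (audio_ids); only placeholder ids are appended to input_ids.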
if multimodal_generation_status.mode == "audio":
# audio_text_ids update
audio_text_next_tokens = self.model.audiotext_pad_token_id.to(input_ids.device)
if not multimodal_generation_status.is_audio_text_end:
audio_text_next_tokens, next_tokens = next_tokens, audio_text_next_tokens
audio_text_ids = torch.cat((audio_text_ids, audio_text_next_tokens[:, None]), dim=1)
# audio_ids update
if multimodal_generation_status.is_audio_start:
                    if outputs.audio_ids[-1, 0] == self.model.audio_offset_vals[1]:  # audio_offset + codebook_sizes[0]: first id past level 0, used as the end-of-audio code
next_tokens = self.model.audiogen_end_token_id.to(input_ids.device)
else:
next_tokens = self.model.audio_pad_token_id.to(input_ids.device)
audio_ids = torch.cat([audio_ids, outputs.audio_ids], dim=0)
elif (multimodal_generation_status.audio_parallel_decoding) or \
(not multimodal_generation_status.audio_parallel_decoding and multimodal_generation_status.is_audio_text_end):
next_tokens = self.model.audiotext_start_token_id.to(input_ids.device)
# finished sentences should have their next token be a padding token
if has_eos_stopping_criteria:
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
# TODO: streaming mm ids
if streamer is not None:
streamer.put(next_tokens.cpu())
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
cur_len += 1
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
del outputs
pbar.update(1)
pbar.set_postfix({
"recent_5toks": f"{input_ids[:, -5:].tolist()}",
})
pbar.close()
if streamer is not None:
streamer.end()
if return_dict_in_generate:
if self.config.is_encoder_decoder:
return LongcatNextForCausalLMGenerateEncoderDecoderOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
visual_ids=visual_ids,
audio_ids=audio_ids,
audio_text_ids=audio_text_ids,
)
else:
return LongcatNextForCausalLMGenerateDecoderOnlyOutput(
sequences=input_ids,
scores=scores,
logits=raw_logits,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
past_key_values=model_kwargs.get("past_key_values"),
visual_ids=visual_ids,
audio_ids=audio_ids,
audio_text_ids=audio_text_ids,
)
else:
return input_ids, visual_ids, audio_ids, audio_text_ids
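# Minimal usage sketch (illustrative only; `ckpt_dir` and the prompt preparation are
# assumptions, not part of this file):
#
#   model = LongcatNextForCausalLM.from_pretrained(ckpt_dir).eval()
#   out = model.generate(input_ids=input_ids, return_dict_in_generate=True)
#   # out.visual_ids / out.audio_ids can then be decoded with
#   # model.model.decode_visual_ids_and_save(...) and model.model.decode_audio_ids_and_save(...)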
__all__ = ["LongcatNextModel", "LongcatNextForCausalLM"]