# Sequential-Hidden-Decoding-8B-n2 / modeling_qwen3_scale_seq.py
"""Qwen3 with scaled sequence length via embedding replication.
Extends Qwen3Model/Qwen3ForCausalLM with scale_seq_times additional
embedding tables. During forward, the original token sequence of length L
is expanded to (1 + scale_seq_times) * L via interleaved multi-stream
embedding, then processed by the standard Qwen3 transformer body.
Architecture overview (n = 1 + scale_seq_times):
- n Embedding tables: E_0 (original), E_1, ..., E_{n-1} (new)
- Interleaved layout: [E_0(t1), E_1(t1), ..., E_0(t2), E_1(t2), ...]
- RoPE positions: 0, 1, 2, ..., n*L - 1 (continuous)
- Standard causal attention over all n*L positions
- Contraction: only the last stream's hidden_state per token goes through
lm_head (the stream with the richest context), matching v4dev behavior.
See: Scale_SeqLen_via_Embedding_Replication.md
"""

from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import Qwen3ForCausalLM, Qwen3Model
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, can_return_tuple

from .configuration_qwen3_scale_seq import Qwen3ScaleSeqConfig


class Qwen3ScaleSeqModel(Qwen3Model):
    """Qwen3Model extended with multi-stream embedding for sequence scaling."""

    config_class = Qwen3ScaleSeqConfig

def __init__(self, config: Qwen3ScaleSeqConfig):
super().__init__(config)
self.scale_seq_times = getattr(config, "scale_seq_times", 0)
if self.scale_seq_times > 0:
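            # Each extra stream is a full (vocab_size x hidden_size) embedding
            # table, so parameter count grows linearly with scale_seq_times.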
self.scale_seq_embed_tokens_list = nn.ModuleList(
[
nn.Embedding(
config.vocab_size,
config.hidden_size,
self.padding_idx,
)
for _ in range(self.scale_seq_times)
]
)
self.post_init()

    def _expand_scale_seq(
self,
input_ids: torch.LongTensor,
hidden_states: torch.FloatTensor,
) -> torch.FloatTensor:
"""Expand hidden_states from (B, T, D) to (B, T * scale, D).
Layout per original token i:
[main_emb_i, scale_seq_1_emb_i, ..., scale_seq_N_emb_i]
Args:
input_ids: (batch, seq_len) original token ids.
hidden_states: (batch, seq_len, hidden) main embedding output.
Returns:
Expanded tensor of shape (batch, seq_len * scale, hidden).
"""
device = hidden_states.device
B, T, D = hidden_states.shape
# (B, T, D) -> (B, T, 1, D)
parts = [hidden_states.unsqueeze(2)]
for s in range(self.scale_seq_times):
emb_module = self.scale_seq_embed_tokens_list[s]
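            # The extra tables may sit on different devices (e.g. under model
            # parallelism), so route the ids to each table's device and move
            # the result back to the main hidden_states device.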
hs_s = emb_module(input_ids.to(emb_module.weight.device)).to(device)
parts.append(hs_s.unsqueeze(2)) # (B, T, 1, D)
# (B, T, scale, D) -> (B, T * scale, D)
expanded = torch.cat(parts, dim=2)
return expanded.reshape(B, T * (self.scale_seq_times + 1), D)
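
    # A minimal standalone sketch of the interleave performed above
    # (hypothetical values, not part of the model):
    #
    #     >>> import torch
    #     >>> e0 = torch.tensor([[[0.0], [1.0], [2.0]]])     # (B=1, T=3, D=1)
    #     >>> e1 = torch.tensor([[[10.0], [11.0], [12.0]]])  # extra stream
    #     >>> torch.cat([e0.unsqueeze(2), e1.unsqueeze(2)], dim=2).reshape(1, 6, 1).squeeze(-1)
    #     tensor([[ 0., 10.,  1., 11.,  2., 12.]])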

    def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
if (
self.scale_seq_times > 0
and input_ids is not None
and inputs_embeds is None
):
scale = self.scale_seq_times + 1
# Compute main embedding, then expand with scale_seq streams
inputs_embeds = self.embed_tokens(input_ids)
inputs_embeds = self._expand_scale_seq(input_ids, inputs_embeds)
B = inputs_embeds.shape[0]
T_expanded = inputs_embeds.shape[1]
# Recompute cache_position and position_ids in expanded space
past_seen_tokens = (
past_key_values.get_seq_length()
if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + T_expanded,
device=inputs_embeds.device,
)
position_ids = cache_position.unsqueeze(0).expand(B, -1)
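            # e.g. a fresh prefill of T = 3 tokens with scale = 2 gives
            # cache_position = [0, 1, 2, 3, 4, 5] over the expanded sequence.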
# Expand attention_mask to match expanded sequence length
if attention_mask is not None:
attention_mask = attention_mask.repeat_interleave(scale, dim=1)
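                # e.g. scale = 2: [[1, 1, 0]] -> [[1, 1, 1, 1, 0, 0]], so every
                # stream of a token shares that token's mask bit.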
input_ids = None # avoid double embedding lookup in super().forward()
return super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)


class Qwen3ScaleSeqForCausalLM(Qwen3ForCausalLM):
    """Qwen3ForCausalLM with multi-stream embedding for sequence scaling.

    Contraction: after the transformer body produces (B, T*scale, D),
    select only the last stream per token (the one with the richest context)
    before applying lm_head, producing (B, T, vocab_size).
    """

    config_class = Qwen3ScaleSeqConfig
    _tied_weights_keys = ["lm_head.weight"]

def __init__(self, config: Qwen3ScaleSeqConfig):
super().__init__(config)
# Replace the inner model with our scaled version
self.model = Qwen3ScaleSeqModel(config)
self.post_init()

    @can_return_tuple
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values=None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        use_cache = use_cache if use_cache is not None else self.config.use_cache
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# ---- scale_seq contraction ----
# Contract expanded hidden_states (B, T*scale, D) back to logical
# token space (B, T, D) by selecting the last stream per token group
# (the stream with the richest context), matching v4dev behavior.
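        # e.g. with scale == 2 and expanded length 6, the strided slice below
        # keeps indices [1, 3, 5]: torch.arange(6)[1::2] -> tensor([1, 3, 5]).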
if self.model.scale_seq_times > 0:
scale = self.model.scale_seq_times + 1
hidden_states = hidden_states[:, scale - 1::scale, :]
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values if use_cache else None,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)


__all__ = ["Qwen3ScaleSeqModel", "Qwen3ScaleSeqForCausalLM"]
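
# Minimal usage sketch (hypothetical checkpoint path; assumes this module is
# importable alongside its config, e.g. via trust_remote_code packaging):
#
#     >>> from transformers import AutoTokenizer
#     >>> model = Qwen3ScaleSeqForCausalLM.from_pretrained("path/to/checkpoint")
#     >>> tok = AutoTokenizer.from_pretrained("path/to/checkpoint")
#     >>> out = model(**tok("Hello", return_tensors="pt"))
#     >>> tuple(out.logits.shape)  # (batch, T, vocab): logical length T, not n*T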