chen-yingfa commited on
Commit
af8fa42
·
verified ·
1 Parent(s): ababf6f

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  ckpt_500/tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  ckpt_500/tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
added_tokens.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "</think>": 151668,
3
+ "</tool_call>": 151658,
4
+ "</tool_response>": 151666,
5
+ "<think>": 151667,
6
+ "<tool_call>": 151657,
7
+ "<tool_response>": 151665,
8
+ "<|box_end|>": 151649,
9
+ "<|box_start|>": 151648,
10
+ "<|endoftext|>": 151643,
11
+ "<|file_sep|>": 151664,
12
+ "<|fim_middle|>": 151660,
13
+ "<|fim_pad|>": 151662,
14
+ "<|fim_prefix|>": 151659,
15
+ "<|fim_suffix|>": 151661,
16
+ "<|im_end|>": 151645,
17
+ "<|im_start|>": 151644,
18
+ "<|image_pad|>": 151655,
19
+ "<|object_ref_end|>": 151647,
20
+ "<|object_ref_start|>": 151646,
21
+ "<|quad_end|>": 151651,
22
+ "<|quad_start|>": 151650,
23
+ "<|repo_name|>": 151663,
24
+ "<|video_pad|>": 151656,
25
+ "<|vision_end|>": 151653,
26
+ "<|vision_pad|>": 151654,
27
+ "<|vision_start|>": 151652
28
+ }
cache.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple, List, Dict, Any
2
+
3
+ import torch
4
+
5
+ from transformers.cache_utils import Cache
6
+
7
+
8
+ class HybridCache(Cache):
9
+ """
10
+ A cache for hybrid contextual states. Some layers are attention, some layers are RNNs.
11
+ """
12
+
13
+ def __init__(
14
+ self,
15
+ seen_tokens: int = 0,
16
+ ):
17
+
18
+ self.states: List[Dict[str, Any]] = []
19
+
20
+ self._seen_tokens = seen_tokens # Used in `generate` to keep tally of how many tokens the cache has seen
21
+
22
+ @property
23
+ def is_compileable(self) -> bool:
24
+ return False
25
+
26
+ @property
27
+ def seen_tokens(self) -> int:
28
+ return self._seen_tokens
29
+
30
+ def __getitem__(self, layer_idx: int) -> Dict[str, Any]:
31
+ if layer_idx < len(self):
32
+ return self.states[layer_idx]
33
+ else:
34
+ raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
35
+
36
+ def __iter__(self):
37
+ for state in self.states:
38
+ yield state
39
+
40
+ def __len__(self):
41
+ return len(self.states)
42
+
43
+ def update(
44
+ self,
45
+ recurrent_state: torch.Tensor | None = None,
46
+ attn_state: Tuple[torch.Tensor, torch.Tensor] | None = None,
47
+ conv_state: Tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None] | None = None,
48
+ ffn_state: torch.Tensor | None = None,
49
+ layer_idx: int = 0,
50
+ offset: Optional[int] = 1,
51
+ cache_kwargs: Optional[Dict[str, Any]] = None,
52
+ ) -> Dict[str, Any]:
53
+ """
54
+ Updates the cache with the new `recurrent_state`/`attn_state`/`conv_state` for the layer `layer_idx`.
55
+
56
+ Args:
57
+ recurrent_state (`torch.Tensor`, `optional`):
58
+ The new recurrent state to cache.
59
+ attn_state (`Tuple[torch.Tensor, torch.Tensor]`, `optional`):
60
+ The new attention key/value states to cache.
61
+ conv_state (`Tuple[torch.Tensor]`, `optional`):
62
+ The new convolution state to cache.
63
+ layer_idx (`int`, defaults to 0):
64
+ The index of the layer to cache the states for.
65
+ offset (`int`, `optional`, defaults to 1):
66
+ The number of new tokens being processed.
67
+ cache_kwargs (`Dict[str, Any]`, `optional`):
68
+ Additional arguments for the cache subclass.
69
+
70
+ Return:
71
+ Dictionary of the updated state.
72
+ """
73
+
74
+ # Update the number of seen tokens
75
+ if layer_idx == 0:
76
+ self._seen_tokens += offset
77
+
78
+ if attn_state is not None:
79
+ input_size = attn_state[0].shape[-2]
80
+ if cache_kwargs is not None:
81
+ window_size = cache_kwargs.get('window_size', None)
82
+ else:
83
+ window_size = None
84
+ if not isinstance(attn_state, Tuple) or len(attn_state) != 2:
85
+ raise ValueError("`attn_state` must be a tuple of two tensors for key/value states")
86
+ if len(self.states) <= layer_idx:
87
+ if attn_state is not None:
88
+ if window_size is not None and input_size > window_size:
89
+ attn_state = (attn_state[0][..., -window_size:, :].contiguous(),
90
+ attn_state[1][..., -window_size:, :].contiguous())
91
+ state = dict(
92
+ recurrent_state=recurrent_state,
93
+ attn_state=attn_state,
94
+ conv_state=conv_state,
95
+ ffn_state=ffn_state
96
+ )
97
+ self.states.append(state)
98
+ else:
99
+ state = self.states[layer_idx]
100
+ if recurrent_state is not None:
101
+ state['recurrent_state'] = recurrent_state
102
+ if attn_state is not None:
103
+ key_state, value_state = state['attn_state']
104
+ if window_size is not None and key_state.shape[-2] == window_size:
105
+ # DO NOT allocate new memory if the cache is full
106
+ # roll the key/value states to the left by `input_size`
107
+ key_state = key_state.roll(-input_size, -2)
108
+ value_state = value_state.roll(-input_size, -2)
109
+ # replace the last `input_size` tokens with the new key/value states
110
+ key_state[..., -input_size:, :] = attn_state[0]
111
+ value_state[..., -input_size:, :] = attn_state[1]
112
+ attn_state = (key_state, value_state)
113
+ else:
114
+ attn_state = (torch.cat([key_state, attn_state[0]], -2),
115
+ torch.cat([value_state, attn_state[1]], -2),)
116
+ state['attn_state'] = attn_state
117
+ if conv_state is not None:
118
+ state['conv_state'] = conv_state
119
+ if ffn_state is not None:
120
+ state['ffn_state'] = ffn_state
121
+
122
+ return state
123
+
124
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
125
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
126
+ if len(self.states) <= layer_idx:
127
+ return 0
128
+ return self._seen_tokens
129
+
130
+ def get_max_length(self) -> Optional[int]:
131
+ """Returns the maximum sequence length of the cached states. Cache does not have a maximum length."""
132
+ return None
133
+
134
+ def to_legacy_cache(self) -> Tuple:
135
+ return tuple(self.states)
136
+
137
+ @classmethod
138
+ def from_legacy_cache(
139
+ cls,
140
+ past_key_values: Optional[Tuple] = None,
141
+ seen_tokens: int = 0
142
+ ) -> "HybridCache":
143
+ """Converts a cache in the legacy cache format into an equivalent `Cache`."""
144
+
145
+ cache = cls(seen_tokens)
146
+ if past_key_values is not None:
147
+ for layer_idx in range(len(past_key_values)):
148
+ cache.states.append(past_key_values[layer_idx])
149
+ return cache
150
+
chat_template.jinja ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {{- messages[0].content + '\n\n' }}
5
+ {%- endif %}
6
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
7
+ {%- for tool in tools %}
8
+ {{- "\n" }}
9
+ {{- tool | tojson }}
10
+ {%- endfor %}
11
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
12
+ {%- else %}
13
+ {%- if messages[0].role == 'system' %}
14
+ {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
15
+ {%- endif %}
16
+ {%- endif %}
17
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
18
+ {%- for message in messages[::-1] %}
19
+ {%- set index = (messages|length - 1) - loop.index0 %}
20
+ {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
21
+ {%- set ns.multi_step_tool = false %}
22
+ {%- set ns.last_query_index = index %}
23
+ {%- endif %}
24
+ {%- endfor %}
25
+ {%- for message in messages %}
26
+ {%- if message.content is string %}
27
+ {%- set content = message.content %}
28
+ {%- else %}
29
+ {%- set content = '' %}
30
+ {%- endif %}
31
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
32
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
33
+ {%- elif message.role == "assistant" %}
34
+ {%- set reasoning_content = '' %}
35
+ {%- if message.reasoning_content is string %}
36
+ {%- set reasoning_content = message.reasoning_content %}
37
+ {%- else %}
38
+ {%- if '</think>' in content %}
39
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
40
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
41
+ {%- endif %}
42
+ {%- endif %}
43
+ {%- if loop.index0 > ns.last_query_index %}
44
+ {%- if loop.last or (not loop.last and reasoning_content) %}
45
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
46
+ {%- else %}
47
+ {{- '<|im_start|>' + message.role + '\n' + content }}
48
+ {%- endif %}
49
+ {%- else %}
50
+ {{- '<|im_start|>' + message.role + '\n' + content }}
51
+ {%- endif %}
52
+ {%- if message.tool_calls %}
53
+ {%- for tool_call in message.tool_calls %}
54
+ {%- if (loop.first and content) or (not loop.first) %}
55
+ {{- '\n' }}
56
+ {%- endif %}
57
+ {%- if tool_call.function %}
58
+ {%- set tool_call = tool_call.function %}
59
+ {%- endif %}
60
+ {{- '<tool_call>\n{"name": "' }}
61
+ {{- tool_call.name }}
62
+ {{- '", "arguments": ' }}
63
+ {%- if tool_call.arguments is string %}
64
+ {{- tool_call.arguments }}
65
+ {%- else %}
66
+ {{- tool_call.arguments | tojson }}
67
+ {%- endif %}
68
+ {{- '}\n</tool_call>' }}
69
+ {%- endfor %}
70
+ {%- endif %}
71
+ {{- '<|im_end|>\n' }}
72
+ {%- elif message.role == "tool" %}
73
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
74
+ {{- '<|im_start|>user' }}
75
+ {%- endif %}
76
+ {{- '\n<tool_response>\n' }}
77
+ {{- content }}
78
+ {{- '\n</tool_response>' }}
79
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
80
+ {{- '<|im_end|>\n' }}
81
+ {%- endif %}
82
+ {%- endif %}
83
+ {%- endfor %}
84
+ {%- if add_generation_prompt %}
85
+ {{- '<|im_start|>assistant\n' }}
86
+ {%- if enable_thinking is defined and enable_thinking is false %}
87
+ {{- '<think>\n\n</think>\n\n' }}
88
+ {%- endif %}
89
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "HybridForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "attn_logits_scaling": "hype 500",
8
+ "attn_sqrtd": true,
9
+ "attn_use_output_gate": true,
10
+ "attn_use_rope": false,
11
+ "auto_map": {
12
+ "AutoConfig": "configuration_hybrid.HybridConfig",
13
+ "AutoModelForCausalLM": "modeling_hybrid.HybridForCausalLM"
14
+ },
15
+ "bos_token_id": 151643,
16
+ "dtype": "bfloat16",
17
+ "eos_token_id": 151645,
18
+ "expand_kv_proj": true,
19
+ "fused_ce_loss": true,
20
+ "gdn_activation": null,
21
+ "gdn_attn_mode": "chunk",
22
+ "gdn_expand_v": 1,
23
+ "gdn_fuse_cross_entropy": false,
24
+ "gdn_nh": 16,
25
+ "gdn_nkv": 16,
26
+ "gdn_use_gate": false,
27
+ "gdn_use_qk_norm": false,
28
+ "gdn_use_rope": false,
29
+ "gdn_use_short_conv": false,
30
+ "head_dim": 128,
31
+ "hidden_act": "silu",
32
+ "hidden_size": 2048,
33
+ "initializer_range": 0.02,
34
+ "intermediate_size": 6144,
35
+ "kda_head_dim": 128,
36
+ "kda_num_heads": 16,
37
+ "kda_use_conv": false,
38
+ "kda_use_qk_norm": true,
39
+ "kda_use_rope": false,
40
+ "lightning_conv_size": 4,
41
+ "lightning_head_dim": 128,
42
+ "lightning_nh": 16,
43
+ "lightning_nkv": 16,
44
+ "lightning_rope_scaling": null,
45
+ "lightning_scale": "1/sqrt(d)",
46
+ "lightning_use_output_gate": true,
47
+ "lightning_use_output_norm": true,
48
+ "lightning_use_qk_norm": true,
49
+ "lightning_use_rope": true,
50
+ "lightning_use_short_conv": false,
51
+ "loss_fn": "kl_div",
52
+ "mamba2_bias": false,
53
+ "mamba2_conv_kernel": 4,
54
+ "mamba2_expand_ratio": 1.0,
55
+ "mamba2_hidden_act": null,
56
+ "mamba2_n_groups": 1,
57
+ "max_position_embeddings": 40960,
58
+ "max_window_layers": 28,
59
+ "mixer_types": [
60
+ "lightning-attn",
61
+ "lightning-attn",
62
+ "attn",
63
+ "attn",
64
+ "lightning-attn",
65
+ "lightning-attn",
66
+ "attn",
67
+ "lightning-attn",
68
+ "attn",
69
+ "attn",
70
+ "lightning-attn",
71
+ "lightning-attn",
72
+ "lightning-attn",
73
+ "lightning-attn",
74
+ "lightning-attn",
75
+ "lightning-attn",
76
+ "lightning-attn",
77
+ "lightning-attn",
78
+ "lightning-attn",
79
+ "lightning-attn",
80
+ "lightning-attn",
81
+ "attn",
82
+ "lightning-attn",
83
+ "lightning-attn",
84
+ "lightning-attn",
85
+ "attn",
86
+ "lightning-attn",
87
+ "lightning-attn"
88
+ ],
89
+ "model_type": "hybrid",
90
+ "num_attention_heads": 16,
91
+ "num_hidden_layers": 28,
92
+ "num_key_value_heads": 8,
93
+ "rand_init": false,
94
+ "rms_norm_eps": 1e-06,
95
+ "rope_scaling": null,
96
+ "rope_theta": 1000000,
97
+ "shift_labels": true,
98
+ "sliding_window": null,
99
+ "tie_word_embeddings": true,
100
+ "transformers_version": "4.57.3",
101
+ "use_cache": true,
102
+ "use_rope": false,
103
+ "use_sliding_window": false,
104
+ "vocab_size": 151936
105
+ }
configuration_hybrid.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Qwen3 model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.modeling_rope_utils import rope_config_validation
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ class HybridConfig(PretrainedConfig):
26
+ r"""
27
+ This is the configuration class to store the configuration of a [`Qwen3Model`]. It is used to instantiate a
28
+ Qwen3 model according to the specified arguments, defining the model architecture. Instantiating a configuration
29
+ with the defaults will yield a similar configuration to that of
30
+ Qwen3-8B [Qwen/Qwen3-8B](https://huggingface.co/Qwen/Qwen3-8B).
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ vocab_size (`int`, *optional*, defaults to 151936):
38
+ Vocabulary size of the Qwen3 model. Defines the number of different tokens that can be represented by the
39
+ `inputs_ids` passed when calling [`Qwen3Model`]
40
+ hidden_size (`int`, *optional*, defaults to 4096):
41
+ Dimension of the hidden representations.
42
+ intermediate_size (`int`, *optional*, defaults to 22016):
43
+ Dimension of the MLP representations.
44
+ num_hidden_layers (`int`, *optional*, defaults to 32):
45
+ Number of hidden layers in the Transformer encoder.
46
+ num_attention_heads (`int`, *optional*, defaults to 32):
47
+ Number of attention heads for each attention layer in the Transformer encoder.
48
+ num_key_value_heads (`int`, *optional*, defaults to 32):
49
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
50
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
51
+ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
52
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
53
+ by meanpooling all the original heads within that group. For more details checkout [this
54
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
55
+ head_dim (`int`, *optional*, defaults to 128):
56
+ The attention head dimension.
57
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
58
+ The non-linear activation function (function or string) in the decoder.
59
+ max_position_embeddings (`int`, *optional*, defaults to 32768):
60
+ The maximum sequence length that this model might ever be used with.
61
+ initializer_range (`float`, *optional*, defaults to 0.02):
62
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
63
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
64
+ The epsilon used by the rms normalization layers.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
69
+ Whether the model's input and output word embeddings should be tied.
70
+ rope_theta (`float`, *optional*, defaults to 10000.0):
71
+ The base period of the RoPE embeddings.
72
+ rope_scaling (`Dict`, *optional*):
73
+ Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
74
+ and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
75
+ accordingly.
76
+ Expected contents:
77
+ `rope_type` (`str`):
78
+ The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
79
+ 'llama3'], with 'default' being the original RoPE implementation.
80
+ `factor` (`float`, *optional*):
81
+ Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
82
+ most scaling types, a `factor` of x will enable the model to handle sequences of length x *
83
+ original maximum pre-trained length.
84
+ `original_max_position_embeddings` (`int`, *optional*):
85
+ Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
86
+ pretraining.
87
+ `attention_factor` (`float`, *optional*):
88
+ Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
89
+ computation. If unspecified, it defaults to value recommended by the implementation, using the
90
+ `factor` field to infer the suggested value.
91
+ `beta_fast` (`float`, *optional*):
92
+ Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
93
+ ramp function. If unspecified, it defaults to 32.
94
+ `beta_slow` (`float`, *optional*):
95
+ Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
96
+ ramp function. If unspecified, it defaults to 1.
97
+ `short_factor` (`List[float]`, *optional*):
98
+ Only used with 'longrope'. The scaling factor to be applied to short contexts (<
99
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
100
+ size divided by the number of attention heads divided by 2
101
+ `long_factor` (`List[float]`, *optional*):
102
+ Only used with 'longrope'. The scaling factor to be applied to long contexts (<
103
+ `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
104
+ size divided by the number of attention heads divided by 2
105
+ `low_freq_factor` (`float`, *optional*):
106
+ Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
107
+ `high_freq_factor` (`float`, *optional*):
108
+ Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
109
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
110
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
111
+ use_sliding_window (`bool`, *optional*, defaults to `False`):
112
+ Whether to use sliding window attention.
113
+ sliding_window (`int`, *optional*, defaults to 4096):
114
+ Sliding window attention (SWA) window size. If not specified, will default to `4096`.
115
+ max_window_layers (`int`, *optional*, defaults to 28):
116
+ The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
117
+ attention_dropout (`float`, *optional*, defaults to 0.0):
118
+ The dropout ratio for the attention probabilities.
119
+
120
+ ```python
121
+ >>> from transformers import Qwen3Model, Qwen3Config
122
+
123
+ >>> # Initializing a Qwen3 style configuration
124
+ >>> configuration = Qwen3Config()
125
+
126
+ >>> # Initializing a model from the Qwen3-8B style configuration
127
+ >>> model = Qwen3Model(configuration)
128
+
129
+ >>> # Accessing the model configuration
130
+ >>> configuration = model.config
131
+ ```"""
132
+
133
+ model_type = "hybrid"
134
+ keys_to_ignore_at_inference = ["past_key_values"]
135
+
136
+ # Default tensor parallel plan for base model `Qwen3`
137
+ base_model_tp_plan = {
138
+ "layers.*.self_attn.q_proj": "colwise",
139
+ "layers.*.self_attn.k_proj": "colwise",
140
+ "layers.*.self_attn.v_proj": "colwise",
141
+ "layers.*.self_attn.o_proj": "rowwise",
142
+ "layers.*.mlp.gate_proj": "colwise",
143
+ "layers.*.mlp.up_proj": "colwise",
144
+ "layers.*.mlp.down_proj": "rowwise",
145
+ }
146
+ base_model_pp_plan = {
147
+ "embed_tokens": (["input_ids"], ["inputs_embeds"]),
148
+ "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
149
+ "norm": (["hidden_states"], ["hidden_states"]),
150
+ }
151
+
152
+ def __init__(
153
+ self,
154
+ vocab_size=151936,
155
+ hidden_size=4096,
156
+ intermediate_size=22016,
157
+ num_hidden_layers=32,
158
+ num_attention_heads=32,
159
+ num_key_value_heads=32,
160
+ head_dim=128,
161
+ mixer_types: list[str] = [],
162
+ hidden_act="silu",
163
+ max_position_embeddings=32768,
164
+ initializer_range=0.02,
165
+ rms_norm_eps=1e-6,
166
+ use_cache=True,
167
+ tie_word_embeddings=False,
168
+ rope_theta=10000.0,
169
+ rope_scaling=None,
170
+ attention_bias=False,
171
+ use_sliding_window=False,
172
+ sliding_window=4096,
173
+ max_window_layers=28,
174
+ attention_dropout=0.0,
175
+ _attn_implementation: str = 'flash_attention_2',
176
+ # Gated DeltaNet
177
+ gdn_use_short_conv: bool = False,
178
+ gdn_use_gate: bool = False,
179
+ gdn_expand_v: int = 1,
180
+ gdn_attn_mode: str = 'chunk',
181
+ gdn_fuse_cross_entropy: bool = False,
182
+ gdn_activation: str | None = None,
183
+ gdn_nh: int | None = None,
184
+ gdn_nkv: int | None = None,
185
+ gdn_use_qk_norm: bool = False,
186
+ gdn_use_rope: bool = False,
187
+ # Mamba2
188
+ mamba2_n_groups: int = 1,
189
+ mamba2_expand_ratio: float = 1.0,
190
+ mamba2_conv_kernel: int = 4,
191
+ mamba2_bias: bool = False,
192
+ mamba2_hidden_act: str | None = None,
193
+ # Lightning attention
194
+ lightning_use_qk_norm: bool = False,
195
+ lightning_use_output_gate: bool = False,
196
+ lightning_use_output_norm: bool = False,
197
+ lightning_use_rope: bool = True,
198
+ lightning_rope_scaling: bool | None = None, # true: use the rope_scaling of the teacher model.
199
+ lightning_nh: int | None = None,
200
+ lightning_nkv: int | None = None,
201
+ lightning_head_dim: int | None = None,
202
+ lightning_scale: str = '1/sqrt(d)',
203
+ lightning_use_short_conv: bool = False,
204
+ lightning_conv_size: int = 4,
205
+ # Kimi Delta Attention
206
+ kda_head_dim: int | None = None,
207
+ kda_num_heads: int | None = None,
208
+ kda_use_conv: bool = False,
209
+ kda_use_qk_norm: bool = True,
210
+ kda_use_rope: bool = False,
211
+ # Other
212
+ expand_kv_proj: bool = False,
213
+ use_rope: bool = False,
214
+ attn_sqrtd: bool = True,
215
+ loss_fn: str = 'kl_div',
216
+ attn_use_rope: bool = True,
217
+ fused_ce_loss: bool = True,
218
+ shift_labels: bool = True,
219
+ attn_logits_scaling: None | str | float = None,
220
+ attn_use_output_gate: bool = False,
221
+ rand_init: bool = False,
222
+ **kwargs,
223
+ ):
224
+ self.vocab_size = vocab_size
225
+ self.max_position_embeddings = max_position_embeddings
226
+ self.hidden_size = hidden_size
227
+ self.intermediate_size = intermediate_size
228
+ self.num_hidden_layers = num_hidden_layers
229
+ self.num_attention_heads = num_attention_heads
230
+ self.use_sliding_window = use_sliding_window
231
+ self.sliding_window = sliding_window # we check `use_sliding_window` in the modeling code
232
+ self.max_window_layers = max_window_layers
233
+ self.mixer_types = mixer_types
234
+ if len(self.mixer_types) == 0:
235
+ # The default config is Qwen3 (full attn in every layer)
236
+ self.mixer_types = ['attn'] * self.num_hidden_layers
237
+ else:
238
+ self.mixer_types = mixer_types
239
+ assert len(self.mixer_types) == self.num_hidden_layers
240
+
241
+ # for backward compatibility
242
+ if num_key_value_heads is None:
243
+ num_key_value_heads = num_attention_heads
244
+
245
+ self.num_key_value_heads = num_key_value_heads
246
+
247
+ if head_dim is None:
248
+ head_dim = self.hidden_size // self.num_attention_heads
249
+
250
+ # For Lightning Attention
251
+ self.head_dim = head_dim
252
+ self.lightning_use_qk_norm = lightning_use_qk_norm
253
+ self.lightning_use_output_norm = lightning_use_output_norm
254
+ self.lightning_use_output_gate = lightning_use_output_gate
255
+ self.lightning_use_rope = lightning_use_rope
256
+ self.lightning_use_short_conv = lightning_use_short_conv
257
+ self.lightning_conv_size = lightning_conv_size
258
+ self.expand_kv_proj = expand_kv_proj
259
+ self.lightning_rope_scaling = lightning_rope_scaling
260
+ self.lightning_nh = lightning_nh if lightning_nh is not None else self.num_attention_heads
261
+ self.lightning_nkv = lightning_nkv if lightning_nkv is not None else self.num_key_value_heads
262
+ self.lightning_head_dim = lightning_head_dim if lightning_head_dim is not None else self.head_dim
263
+ self.lightning_scale = lightning_scale
264
+ self.attn_use_rope = attn_use_rope
265
+ self.fused_ce_loss = fused_ce_loss
266
+ self.shift_labels = shift_labels
267
+ self.attn_logits_scaling = attn_logits_scaling
268
+ self.attn_use_output_gate = attn_use_output_gate
269
+
270
+ # Kimi Delta Attention
271
+ self.kda_head_dim = kda_head_dim if kda_head_dim is not None else self.head_dim
272
+ self.kda_num_heads = kda_num_heads if kda_num_heads is not None else self.num_attention_heads
273
+ self.kda_use_conv = kda_use_conv
274
+ self.kda_use_qk_norm = kda_use_qk_norm
275
+ self.kda_use_rope = kda_use_rope
276
+
277
+ # Others
278
+ self.hidden_act = hidden_act
279
+ self.initializer_range = initializer_range
280
+ self.rms_norm_eps = rms_norm_eps
281
+ self.use_cache = use_cache
282
+ self.rope_theta = rope_theta
283
+ self.rope_scaling = rope_scaling
284
+ self.attention_bias = attention_bias
285
+ self.attention_dropout = attention_dropout
286
+ # Validate the correctness of rotary position embeddings parameters
287
+ # BC: if there is a 'type' field, move it to 'rope_type'.
288
+ if self.rope_scaling is not None and "type" in self.rope_scaling:
289
+ self.rope_scaling["rope_type"] = self.rope_scaling["type"]
290
+ rope_config_validation(self)
291
+
292
+ # Gated DeltaNet (GDN)
293
+ self.gdn_use_short_conv = gdn_use_short_conv
294
+ self.gdn_use_gate = gdn_use_gate
295
+ self.gdn_expand_v = gdn_expand_v
296
+ self.gdn_attn_mode = gdn_attn_mode
297
+ self.gdn_fuse_cross_entropy = gdn_fuse_cross_entropy
298
+ self.gdn_activation = gdn_activation
299
+ self.gdn_nh = gdn_nh if gdn_nh is not None else self.num_attention_heads
300
+ self.gdn_nkv = gdn_nkv if gdn_nkv is not None else self.num_key_value_heads
301
+ self.gdn_use_qk_norm = gdn_use_qk_norm
302
+ self.gdn_use_rope = gdn_use_rope
303
+
304
+ # Mamba2
305
+ self.mamba2_n_groups = mamba2_n_groups
306
+ self.mamba2_expand_ratio = mamba2_expand_ratio
307
+ self.mamba2_conv_kernel = mamba2_conv_kernel
308
+ self.mamba2_bias = mamba2_bias
309
+ self.mamba2_hidden_act = mamba2_hidden_act
310
+
311
+ # Other
312
+ self.use_rope = use_rope
313
+ self.attn_sqrtd = attn_sqrtd
314
+ self.loss_fn = loss_fn
315
+ self.rand_init = rand_init
316
+
317
+ super().__init__(
318
+ tie_word_embeddings=tie_word_embeddings,
319
+ _attn_implementation=_attn_implementation,
320
+ **kwargs,
321
+ )
gdn.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ from torch import Tensor, nn
11
+ from einops import rearrange, repeat
12
+ from torch.nn import functional as F
13
+
14
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
15
+ from fla.modules.l2norm import l2_norm
16
+ from fla.ops.gated_delta_rule import (
17
+ chunk_gated_delta_rule,
18
+ fused_recurrent_gated_delta_rule,
19
+ )
20
+ from .configuration_hybrid import HybridConfig
21
+ from .modeling_qwen3 import Qwen3Attention, apply_rotary_pos_emb
22
+
23
+ if TYPE_CHECKING:
24
+ from transformers.processing_utils import Unpack
25
+
26
+ from fla.models.utils import Cache
27
+
28
+
29
+ def elu_p1(x):
30
+ return (F.elu(x, 1., False) + 1.).to(x)
31
+
32
+
33
+ def sum_norm(x):
34
+ return (x / x.sum(-1, keepdim=True)).to(x)
35
+
36
+ # https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
37
+
38
+
39
+ class GatedDeltaNet(nn.Module):
40
+ """
41
+ The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa
42
+
43
+ Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters.
44
+ Parameter alloation when use_gate=True:
45
+ - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each
46
+ - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each
47
+ - Others are ignorably small.
48
+ - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size
49
+ NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim.
50
+
51
+ Parameter allocation when use_gate=False:
52
+ - 1 * hidden_size * hidden_size for the q_proj and k_proj each
53
+ - 2 * hidden_size * hidden_size for the v_proj and o_proj each
54
+ - Others are ignorably small.
55
+ - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size
56
+
57
+ Args:
58
+ hidden_size (int, Optional):
59
+ The hidden size of the input. Default: 2048.
60
+ expand_v (float, Optional):
61
+ The expansion ratio for the value dim. Default: 2.0.
62
+ head_dim (int, Optional):
63
+ The dimension of each head. Default: 256.
64
+ num_heads (int, Optional):
65
+ The number of heads. Default: 4.
66
+ mode (str, Optional):
67
+ Which Gated DeltaNet kernel to use.
68
+ Currently available: `chunk` and `fused_recurrent`.
69
+ Default: `chunk`.
70
+ use_beta (bool, Optional):
71
+ Whether to use beta. Default: `True`.
72
+ use_gate (bool, Optional):
73
+ Whether to use output gate. Default: `True`.
74
+ use_short_conv (bool, Optional):
75
+ Whether to use short convolutions. Default: `True`.
76
+ conv_size (int, Optional):
77
+ The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4.
78
+ conv_bias (bool, Optional):
79
+ Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`.
80
+ layer_idx (int, Optional):
81
+ The index of the layer. Default: None.
82
+ norm_eps (float, Optional):
83
+ The epsilon value for the normalization layer. Default: 1e-5.
84
+ """
85
+
86
+ def __init__(
87
+ self,
88
+ layer_idx: Optional[int] = None,
89
+ hidden_size: int = 2048,
90
+ expand_v: float = 2,
91
+ # head_dim: int = 256,
92
+ key_dim: int = 128,
93
+ val_dim: int = 128,
94
+ num_heads: int = 32,
95
+ num_kv_heads: int = 8,
96
+ mode: str = 'chunk',
97
+ use_gate: bool = True,
98
+ use_short_conv: bool = True,
99
+ conv_size: int = 4,
100
+ conv_bias: bool = False,
101
+ norm_eps: float = 1e-5,
102
+ activation: Optional[str] = None,
103
+ qk_norm: bool = False,
104
+ use_rope: bool = False,
105
+ **kwargs,
106
+ ):
107
+ super().__init__()
108
+
109
+ self.mode = mode
110
+
111
+ self.hidden_size = hidden_size
112
+ self.expand_v = expand_v
113
+
114
+ self.use_gate = use_gate
115
+ self.use_short_conv = use_short_conv
116
+ self.conv_size = conv_size
117
+ self.conv_bias = conv_bias
118
+
119
+ # self.head_dim = head_dim
120
+ self.key_dim = key_dim
121
+ self.val_dim = val_dim
122
+ self.num_heads = num_heads
123
+ self.num_kv_heads = num_kv_heads
124
+
125
+ self.k_dim = self.num_kv_heads * key_dim
126
+ self.v_dim = self.num_kv_heads * val_dim
127
+ self.q_dim = self.num_heads * key_dim
128
+ self.layer_idx = layer_idx
129
+ self.activation = activation
130
+ self.qk_norm = qk_norm
131
+ self.use_rope = use_rope
132
+ self.silu = nn.SiLU()
133
+
134
+ assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
135
+
136
+ if self.qk_norm:
137
+ self.q_norm = RMSNorm(key_dim, eps=norm_eps)
138
+ self.k_norm = RMSNorm(key_dim, eps=norm_eps)
139
+ self.q_proj = nn.Linear(hidden_size, self.q_dim, bias=False)
140
+ self.k_proj = nn.Linear(hidden_size, self.k_dim, bias=False)
141
+ self.v_proj = nn.Linear(hidden_size, self.v_dim, bias=False)
142
+ self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
143
+ self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
144
+ A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
145
+ A_log = torch.log(A)
146
+ self.A_log = nn.Parameter(A_log)
147
+ self.A_log._no_weight_decay = True
148
+ # self.D = nn.Parameter(torch.ones(self.num_heads))
149
+ # self.D._no_weight_decay = True
150
+ # hard coded for now
151
+ dt_min = 0.001
152
+ dt_max = 0.1
153
+ dt_init_floor = 1e-4
154
+ dt = torch.exp(
155
+ torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
156
+ + math.log(dt_min)
157
+ )
158
+ dt = torch.clamp(dt, min=dt_init_floor)
159
+ # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
160
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
161
+ self.dt_bias = nn.Parameter(inv_dt)
162
+ # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
163
+ # name.endswith("bias") in param_grouping.py
164
+ self.dt_bias._no_weight_decay = True
165
+
166
+ if use_short_conv:
167
+ self.conv_size = conv_size
168
+ self.q_conv1d = ShortConvolution(
169
+ hidden_size=self.key_dim,
170
+ kernel_size=conv_size,
171
+ activation='silu',
172
+ use_fast_conv1d=False,
173
+ )
174
+ self.k_conv1d = ShortConvolution(
175
+ hidden_size=self.key_dim,
176
+ kernel_size=conv_size,
177
+ activation='silu',
178
+ use_fast_conv1d=False,
179
+ )
180
+ self.v_conv1d = ShortConvolution(
181
+ hidden_size=self.v_dim,
182
+ kernel_size=conv_size,
183
+ activation='silu',
184
+ use_fast_conv1d=False,
185
+ )
186
+ # else:
187
+ # raise UserWarning(
188
+ # "ShortConvolution is crucial to the performance. "
189
+ # "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
190
+ # )
191
+ if use_gate:
192
+ self.g_proj = nn.Linear(hidden_size, self.num_heads * self.val_dim, bias=False)
193
+ self.o_norm = FusedRMSNormSwishGate(self.val_dim, eps=norm_eps)
194
+ else:
195
+ self.o_norm = RMSNorm(self.val_dim, eps=norm_eps)
196
+ self.o_proj = nn.Linear(self.num_heads * self.val_dim, hidden_size, bias=False)
197
+ self.apply(self._initialize_weights)
198
+
199
+ def _initialize_weights(self, module: nn.Module):
200
+ if getattr(module, "_is_hf_initialized", False):
201
+ return
202
+ if isinstance(module, nn.Linear):
203
+ nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
204
+ if module.bias is not None:
205
+ nn.init.zeros_(module.bias)
206
+ module._is_hf_initialized = True
207
+
208
+ def forward(
209
+ self,
210
+ hidden_states: torch.Tensor,
211
+ attention_mask: Optional[torch.Tensor] = None,
212
+ past_key_values: Optional[Cache] = None,
213
+ use_cache: Optional[bool] = False,
214
+ output_attentions: Optional[bool] = False,
215
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
216
+ **kwargs,
217
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
218
+ attention_mask = None
219
+ if attention_mask is not None:
220
+ assert len(attention_mask.shape) == 2, (
221
+ "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
222
+ "for padding purposes (0 indicating padding). "
223
+ "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
224
+ )
225
+
226
+ mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
227
+ if self.training:
228
+ assert mode == 'chunk', "Only chunk mode is supported in training."
229
+
230
+ last_state = None
231
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
232
+ last_state = past_key_values[self.layer_idx]
233
+
234
+ if self.use_short_conv:
235
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
236
+ if last_state is not None:
237
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
238
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
239
+ q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states),
240
+ mask=conv_mask,
241
+ cache=conv_state_q,
242
+ output_final_state=use_cache)
243
+ k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states),
244
+ mask=conv_mask,
245
+ cache=conv_state_k,
246
+ output_final_state=use_cache)
247
+ v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states),
248
+ mask=conv_mask,
249
+ cache=conv_state_v,
250
+ output_final_state=use_cache)
251
+ else:
252
+ q = self.q_proj(hidden_states)
253
+ k = self.k_proj(hidden_states)
254
+ v = self.v_proj(hidden_states)
255
+ if self.activation is not None:
256
+ q = self.silu(q)
257
+ k = self.silu(k)
258
+ v = self.silu(v)
259
+
260
+ q = rearrange(q, 'b t (h d) -> b t h d', d=self.key_dim)
261
+ k = rearrange(k, 'b t (h d) -> b t h d', d=self.key_dim)
262
+ v = rearrange(v, 'b t (h d) -> b t h d', d=self.val_dim)
263
+
264
+ if self.qk_norm:
265
+ q = self.q_norm(q)
266
+ k = self.k_norm(k)
267
+
268
+ if self.use_rope:
269
+ assert position_embeddings is not None
270
+ cos, sin = position_embeddings
271
+ q, k = q.transpose(1, 2), k.transpose(1, 2)
272
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
273
+ q, k = q.transpose(1, 2), k.transpose(1, 2)
274
+
275
+ q = l2_norm(q)
276
+ k = l2_norm(k)
277
+ # Allow negative eigenvalues
278
+ beta = self.b_proj(hidden_states).sigmoid() * 2
279
+ g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
280
+
281
+ # Handle grouped-query, maybe we should untie the weights to go back to MHA?
282
+ if self.num_kv_heads < self.num_heads:
283
+ group_size = self.num_heads // self.num_kv_heads
284
+ k = repeat(k, 'b t h d -> b t (h g) d', g=group_size) # (B, T, nh, dh)
285
+ v = repeat(v, 'b t h d -> b t (h g) d', g=group_size) # (B, T, nh, dh)
286
+
287
+ # dealing with padding
288
+ if attention_mask is not None:
289
+ beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
290
+ g = g.mul(attention_mask[:, -g.shape[-2]:, None])
291
+
292
+ recurrent_state = last_state['recurrent_state'] if last_state is not None else None
293
+ # offsets = kwargs.get('offsets', None)
294
+ if mode == 'chunk':
295
+ o, recurrent_state = chunk_gated_delta_rule(
296
+ q=q,
297
+ k=k,
298
+ v=v,
299
+ g=g,
300
+ beta=beta,
301
+ initial_state=recurrent_state,
302
+ output_final_state=use_cache,
303
+ # offsets=offsets,
304
+ # head_first=False
305
+ )
306
+ elif mode == 'fused_recurrent':
307
+ o, recurrent_state = fused_recurrent_gated_delta_rule(
308
+ q=q,
309
+ k=k,
310
+ v=v,
311
+ g=g,
312
+ beta=beta,
313
+ initial_state=recurrent_state,
314
+ output_final_state=use_cache,
315
+ # offsets=offsets,
316
+ # head_first=False
317
+ )
318
+ if past_key_values is not None:
319
+ past_key_values.update(
320
+ recurrent_state=recurrent_state,
321
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
322
+ layer_idx=self.layer_idx,
323
+ offset=q.shape[2]
324
+ )
325
+
326
+ if self.use_gate:
327
+ g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
328
+ o = self.o_norm(o, g)
329
+ else:
330
+ o = self.o_norm(o)
331
+ o = rearrange(o, 'b t h d -> b t (h d)')
332
+ o = self.o_proj(o)
333
+
334
+ return o, None, past_key_values
335
+
336
+
337
+
338
+ def build_gdn_with_attn(
339
+ attn_layer: Qwen3Attention,
340
+ layer_idx: int,
341
+ config: HybridConfig,
342
+ ) -> nn.Module:
343
+ """
344
+ Initialize a Gated DeltaNet block using the parameters of a Qwen3Attention layer.
345
+ We instantiate the GDN block such that the QKVO projections have the same shape,
346
+ then copy the weights from the Qwen3Attention layer.
347
+ """
348
+
349
+ gdn_block = GatedDeltaNet(
350
+ hidden_size=config.hidden_size,
351
+ layer_idx=layer_idx,
352
+ expand_v=1.0,
353
+ num_heads=config.gdn_nh,
354
+ num_kv_heads=config.gdn_nkv,
355
+ key_dim=config.head_dim,
356
+ val_dim=config.head_dim,
357
+ use_short_conv=config.gdn_use_short_conv,
358
+ use_gate=config.gdn_use_gate,
359
+ norm_eps=config.rms_norm_eps,
360
+ activation=config.gdn_activation,
361
+ qk_norm=config.gdn_use_qk_norm,
362
+ use_rope=config.gdn_use_rope,
363
+ )
364
+
365
+ q_proj: nn.Linear = attn_layer.q_proj
366
+ k_proj: nn.Linear = attn_layer.k_proj
367
+ v_proj: nn.Linear = attn_layer.v_proj
368
+ o_proj: nn.Linear = attn_layer.o_proj
369
+ # Note that the `.weight.shape` for a projection from d1 to d2 is (d2, d1)
370
+ wq: Tensor = q_proj.weight # (nh * dh, d)
371
+ wk: Tensor = k_proj.weight # (nkv * dh, d)
372
+ wv: Tensor = v_proj.weight # (nkv * dh, d)
373
+ wo: Tensor = o_proj.weight # (d, nh * dh)
374
+
375
+ if config.expand_kv_proj:
376
+ wk = wk.reshape(-1, config.head_dim, config.hidden_size)
377
+ wv = wv.reshape(-1, config.head_dim, config.hidden_size)
378
+ assert wk.shape[1] == wv.shape[1], wk.shape[1] == config.num_key_value_heads
379
+
380
+ # Repeat KV projections to convert it to MHA
381
+ target_kv_size = config.lightning_nkv * config.lightning_head_dim
382
+ orig_kv_size = config.num_key_value_heads * config.head_dim
383
+ expand_size = target_kv_size // orig_kv_size
384
+ wk = wk.repeat_interleave(expand_size, dim=0)
385
+ wv = wv.repeat_interleave(expand_size, dim=0)
386
+
387
+ wk = wk.reshape(-1, config.hidden_size)
388
+ wv = wv.reshape(-1, config.hidden_size)
389
+
390
+ # ==== Create target module ====
391
+ gdn_block.q_proj.weight.data.copy_(wq)
392
+ gdn_block.k_proj.weight.data.copy_(wk)
393
+ gdn_block.v_proj.weight.data.copy_(wv)
394
+ gdn_block.o_proj.weight.data.copy_(wo)
395
+
396
+ if hasattr(gdn_block, 'q_norm') and hasattr(attn_layer, 'q_norm'):
397
+ gdn_block.q_norm.weight.data.copy_(attn_layer.q_norm.weight.data.clone())
398
+
399
+ if hasattr(gdn_block, 'k_norm') and hasattr(attn_layer, 'k_norm'):
400
+ gdn_block.k_norm.weight.data.copy_(attn_layer.k_norm.weight.data.clone())
401
+
402
+
403
+ return gdn_block
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 151643,
4
+ "eos_token_id": 151645,
5
+ "transformers_version": "4.57.3"
6
+ }
kda.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+ from torch import nn
3
+ import torch
4
+ from einops import rearrange, repeat
5
+ try:
6
+ from fla.modules import FusedRMSNormGated, ShortConvolution
7
+ from fla.ops.kda import chunk_kda, fused_recurrent_kda
8
+ from fla.ops.kda.gate import fused_kda_gate
9
+ from fla.ops.utils.index import prepare_cu_seqlens_from_mask, prepare_lens_from_mask
10
+ from fla.utils import tensor_cache
11
+ except ImportError:
12
+ raise ImportError("Plese run `pip install -U fla-core`")
13
+ from .configuration_hybrid import HybridConfig
14
+ from .cache import HybridCache
15
+ from .modeling_qwen3 import Qwen3RMSNorm, apply_rotary_pos_emb
16
+
17
+
18
+ def index_first_axis(x, indices):
19
+ other_shape = x.shape[1:]
20
+ second_dim = other_shape.numel()
21
+ return torch.gather(
22
+ rearrange(x, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim),
23
+ ).reshape(-1, *other_shape)
24
+
25
+
26
+ def index_put_first_axis(x, indices, first_axis_dim):
27
+ y = torch.zeros(first_axis_dim, *x.shape[1:], device=x.device, dtype=x.dtype)
28
+ # TODO [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
29
+ y[indices] = x
30
+ # y.scatter_(0, repeat(indices, 'z -> z d', d=x.shape[1]), x)
31
+ return y
32
+
33
+
34
+ @tensor_cache
35
+ def get_unpad_data(
36
+ attention_mask: torch.Tensor,
37
+ ) -> tuple[torch.Tensor, torch.Tensor, int]:
38
+ lens = prepare_lens_from_mask(attention_mask)
39
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
40
+ max_seqlen_in_batch = lens.max().item()
41
+ cu_seqlens = prepare_cu_seqlens_from_mask(attention_mask)
42
+ return indices, cu_seqlens, max_seqlen_in_batch
43
+
44
+
45
+ def unpad_input(
46
+ q: torch.Tensor,
47
+ states: tuple[torch.Tensor],
48
+ attention_mask: torch.Tensor,
49
+ q_len: int,
50
+ keepdim: bool = False,
51
+ ):
52
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = get_unpad_data(attention_mask)
53
+ batch_size, seq_len, *_ = states[0].shape
54
+
55
+ state = tuple(
56
+ index_first_axis(rearrange(s, "b s ... -> (b s) ..."), indices_k)
57
+ for s in states
58
+ )
59
+
60
+ if q_len == seq_len:
61
+ q = index_first_axis(rearrange(q, "b s ... -> (b s) ..."), indices_k)
62
+ cu_seqlens_q = cu_seqlens_k
63
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
64
+ indices_q = indices_k
65
+ elif q_len == 1:
66
+ max_seqlen_in_batch_q = 1
67
+ cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
68
+ indices_q = cu_seqlens_q[:-1]
69
+ q = q.squeeze(1)
70
+ else:
71
+ raise NotImplementedError("We only support either q_len == k_len (prefilling) or q_len == 1 (decoding)")
72
+
73
+ if keepdim:
74
+ q = q.unsqueeze(0)
75
+ state = tuple(s.unsqueeze(0) for s in state)
76
+
77
+ return (
78
+ q,
79
+ state,
80
+ indices_q,
81
+ (cu_seqlens_q, cu_seqlens_k),
82
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
83
+ )
84
+
85
+
86
+ def pad_input(
87
+ hidden_states: torch.Tensor,
88
+ indices: torch.LongTensor,
89
+ batch_size: int,
90
+ seq_len: int,
91
+ ) -> torch.Tensor:
92
+ output = index_put_first_axis(hidden_states, indices, batch_size * seq_len)
93
+ return rearrange(output, "(b s) ... -> b s ...", b=batch_size)
94
+
95
+
96
+ class KimiDeltaAttention(nn.Module):
97
+ def __init__(self, config: HybridConfig, layer_idx: int):
98
+ super().__init__()
99
+ self.config = config
100
+ self.mode = "chunk"
101
+
102
+ self.hidden_size = config.hidden_size
103
+ self.head_dim = config.kda_head_dim
104
+ self.num_heads = config.kda_num_heads
105
+ self.head_k_dim = self.head_dim
106
+ self.num_k_heads = self.num_heads
107
+ self.use_conv = config.kda_use_conv
108
+ self.use_qk_norm = config.kda_use_qk_norm
109
+ self.use_rope = config.kda_use_rope
110
+
111
+ self.layer_idx = layer_idx
112
+
113
+ assert self.mode in [
114
+ 'chunk', 'fused_recurrent'], f"Not suppoerted mode `{self.mode}`."
115
+
116
+ projection_k_size = self.head_k_dim * self.num_k_heads
117
+ projection_size = self.head_dim * self.num_heads
118
+
119
+ self.q_proj = nn.Linear(
120
+ self.hidden_size, projection_k_size, bias=False)
121
+ self.k_proj = nn.Linear(
122
+ self.hidden_size, projection_k_size, bias=False)
123
+ self.v_proj = nn.Linear(self.hidden_size, projection_size, bias=False)
124
+
125
+ if self.use_qk_norm:
126
+ self.q_norm = Qwen3RMSNorm(
127
+ self.head_dim, eps=config.rms_norm_eps)
128
+ self.k_norm = Qwen3RMSNorm(
129
+ self.head_dim, eps=config.rms_norm_eps)
130
+
131
+ if self.use_conv:
132
+ self.conv_size = self.config.kda_conv_size
133
+ self.q_conv1d = ShortConvolution(
134
+ hidden_size=projection_k_size,
135
+ kernel_size=self.conv_size,
136
+ activation='silu',
137
+ )
138
+ self.k_conv1d = ShortConvolution(
139
+ hidden_size=projection_k_size,
140
+ kernel_size=self.conv_size,
141
+ activation='silu',
142
+ )
143
+ self.v_conv1d = ShortConvolution(
144
+ hidden_size=projection_size,
145
+ kernel_size=self.conv_size,
146
+ activation='silu',
147
+ )
148
+
149
+ self.A_log = torch.nn.Parameter(torch.log(torch.empty(
150
+ self.num_heads, dtype=torch.float32).uniform_(1, 16)).view(1, 1, -1, 1))
151
+
152
+ self.f_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
153
+ self.f_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
154
+
155
+ self.dt_bias = nn.Parameter(
156
+ torch.empty(projection_size, dtype=torch.float32))
157
+
158
+ self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
159
+
160
+ self.g_a_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False)
161
+ self.g_b_proj = nn.Linear(self.head_dim, projection_size, bias=False)
162
+
163
+ self.o_norm = FusedRMSNormGated(
164
+ self.head_dim, eps=config.rms_norm_eps, activation='sigmoid')
165
+ self.o_proj = nn.Linear(projection_size, self.hidden_size, bias=False)
166
+
167
+ def forward(
168
+ self,
169
+ hidden_states: torch.Tensor,
170
+ attention_mask: torch.Tensor | None = None,
171
+ past_key_values: HybridCache | None = None,
172
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
173
+ **kwargs,
174
+ ) -> tuple[torch.Tensor, torch.Tensor | None, HybridCache | None]:
175
+ if attention_mask is not None:
176
+ if attention_mask.dim() != 2:
177
+ attention_mask = kwargs.get("padding_mask")
178
+
179
+ if attention_mask is not None and attention_mask.dim() != 2:
180
+ raise ValueError(
181
+ "attention_mask must be a 0-1 matrix of shape [batch_size, seq_len] "
182
+ "(0 = padding). 3D masks are not supported here.",
183
+ )
184
+ use_cache = past_key_values is not None
185
+ batch_size, q_len, _ = hidden_states.shape
186
+ mode = 'fused_recurrent' if q_len <= 64 else self.mode
187
+ if self.training:
188
+ assert mode == 'chunk', "Only chunk mode is supported in training."
189
+
190
+ cu_seqlens = kwargs.get('cu_seqlens')
191
+ indices = None
192
+ if attention_mask is not None:
193
+ indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
194
+ hidden_states = index_first_axis(
195
+ rearrange(hidden_states, "b s ... -> (b s) ..."), indices).unsqueeze(0)
196
+
197
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
198
+
199
+ if self.use_conv:
200
+ # Get convolution states from cache
201
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
202
+ conv_state_q, conv_state_k, conv_state_v = past_key_values[self.layer_idx]['conv_state']
203
+
204
+ # Compute short conv
205
+ q, conv_state_q = self.q_conv1d(
206
+ x=self.q_proj(hidden_states),
207
+ cache=conv_state_q,
208
+ output_final_state=use_cache,
209
+ cu_seqlens=cu_seqlens,
210
+ )
211
+ k, conv_state_k = self.k_conv1d(
212
+ x=self.k_proj(hidden_states),
213
+ cache=conv_state_k,
214
+ output_final_state=use_cache,
215
+ cu_seqlens=cu_seqlens,
216
+ )
217
+ v, conv_state_v = self.v_conv1d(
218
+ x=self.v_proj(hidden_states),
219
+ cache=conv_state_v,
220
+ output_final_state=use_cache,
221
+ cu_seqlens=cu_seqlens,
222
+ )
223
+ else:
224
+ q = self.q_proj(hidden_states)
225
+ k = self.k_proj(hidden_states)
226
+ v = self.v_proj(hidden_states)
227
+
228
+ g = self.f_b_proj(self.f_a_proj(hidden_states))
229
+ g = fused_kda_gate(g, self.A_log, self.head_dim, g_bias=self.dt_bias)
230
+ beta = self.b_proj(hidden_states).float().sigmoid()
231
+
232
+ q, k = map(lambda x: rearrange(
233
+ x, '... (h d) -> ... h d', d=self.head_k_dim), (q, k))
234
+ v = rearrange(v, '... (h d) -> ... h d', d=self.head_dim)
235
+
236
+ if self.use_qk_norm:
237
+ q = self.q_norm(q)
238
+ k = self.k_norm(k)
239
+
240
+ if self.use_rope:
241
+ assert (
242
+ position_embeddings is not None
243
+ ), "position_embeddings is required when use_rope is True"
244
+ cos, sin = position_embeddings
245
+ q, k = q.transpose(1, 2), k.transpose(1, 2)
246
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
247
+ q, k = q.transpose(1, 2), k.transpose(1, 2)
248
+
249
+ # Get recurrent state from cache
250
+ recurrent_state = None
251
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
252
+ recurrent_state = past_key_values[self.layer_idx]['recurrent_state']
253
+ if mode == 'chunk':
254
+ o, recurrent_state = chunk_kda(
255
+ q=q,
256
+ k=k,
257
+ v=v,
258
+ g=g,
259
+ beta=beta,
260
+ initial_state=recurrent_state,
261
+ output_final_state=True,
262
+ use_qk_l2norm_in_kernel=True,
263
+ cu_seqlens=cu_seqlens,
264
+ )
265
+ else:
266
+ o, recurrent_state = fused_recurrent_kda(
267
+ q=q,
268
+ k=k,
269
+ v=v,
270
+ g=g,
271
+ beta=beta,
272
+ initial_state=recurrent_state,
273
+ output_final_state=True,
274
+ use_qk_l2norm_in_kernel=True,
275
+ cu_seqlens=cu_seqlens,
276
+ )
277
+ if past_key_values is not None:
278
+ past_key_values.update(
279
+ recurrent_state=recurrent_state,
280
+ conv_state=(conv_state_q, conv_state_k, conv_state_v),
281
+ layer_idx=self.layer_idx,
282
+ )
283
+
284
+ g = self.g_b_proj(self.g_a_proj(hidden_states))
285
+ g = rearrange(g, '... (h d) -> ... h d', d=self.head_dim)
286
+ o = self.o_norm(o, g)
287
+
288
+ o = rearrange(o, 'b t h d -> b t (h d)')
289
+ o = self.o_proj(o)
290
+ if attention_mask is not None:
291
+ o = pad_input(o.squeeze(0), indices, batch_size, q_len)
292
+
293
+ return o, None, None
294
+
295
+
296
+
297
+ def build_kda_with_attn(
298
+ attn_layer: nn.Module,
299
+ config: HybridConfig,
300
+ layer_idx: int,
301
+ ) -> nn.Module:
302
+
303
+ layer = KimiDeltaAttention(
304
+ config=config,
305
+ layer_idx=layer_idx,
306
+ )
307
+
308
+ # print('============ KDA layer ============')
309
+ # print(f"Layer idx: {layer_idx}")
310
+ # print(layer)
311
+ # print('==================================================')
312
+
313
+ if config.rand_init:
314
+ return layer
315
+
316
+ q_proj = attn_layer.q_proj
317
+ k_proj = attn_layer.k_proj
318
+ v_proj = attn_layer.v_proj
319
+ o_proj = attn_layer.o_proj
320
+
321
+ # (nh * head_dim, hidden_size)
322
+ wq = q_proj.weight.data.clone() # type: ignore
323
+ wk = k_proj.weight.data.clone() # type: ignore
324
+ wv = v_proj.weight.data.clone() # type: ignore
325
+ wo = o_proj.weight.data.clone() # type: ignore
326
+
327
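+ # If the source layer used GQA, repeat its K/V projection rows head-wise so the
+ # converted layer gets one K/V head per target head (MHA layout).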
+ if config.expand_kv_proj:
328
+ wk = wk.reshape(-1, config.head_dim, config.hidden_size)
329
+ wv = wv.reshape(-1, config.head_dim, config.hidden_size)
330
+ assert wk.shape[1] == wv.shape[1] and wk.shape[0] == config.num_key_value_heads
331
+
332
+ # Repeat KV projections to convert it to MHA
333
+ target_kv_size = config.lightning_nkv * config.lightning_head_dim
334
+ orig_kv_size = config.num_key_value_heads * config.head_dim
335
+ expand_size = target_kv_size // orig_kv_size
336
+ wk = wk.repeat_interleave(expand_size, dim=0)
337
+ wv = wv.repeat_interleave(expand_size, dim=0)
338
+
339
+ wk = wk.reshape(-1, config.hidden_size)
340
+ wv = wv.reshape(-1, config.hidden_size)
341
+
342
+ # print(layer)
343
+ # print(wq.shape)
344
+ # print(wk.shape)
345
+ # print(wv.shape)
346
+ # print(wo.shape)
347
+ # print(layer.q_proj.weight.shape)
348
+ # print(layer.k_proj.weight.shape)
349
+ # print(layer.v_proj.weight.shape)
350
+ # print(layer.o_proj.weight.shape)
351
+ # exit()
352
+
353
+ layer.q_proj.weight.data.copy_(wq)
354
+ layer.k_proj.weight.data.copy_(wk)
355
+ layer.v_proj.weight.data.copy_(wv)
356
+ layer.o_proj.weight.data.copy_(wo)
357
+
358
+ if hasattr(attn_layer, 'k_norm'):
359
+ k_norm_weights = attn_layer.k_norm.weight.data.clone()
360
+ layer.k_norm.weight.data.copy_(k_norm_weights)
361
+
362
+ if hasattr(layer, 'q_norm'):
363
+ q_norm_weights = attn_layer.q_norm.weight.data.clone()
364
+ layer.q_norm.weight.data.copy_(q_norm_weights)
365
+
366
+ return layer
lightning_attn.py ADDED
@@ -0,0 +1,451 @@
1
+ import torch
2
+ from torch import nn, Tensor
3
+ from typing import Optional, Tuple
4
+ from einops import rearrange, repeat
5
+ import math
6
+ from transformers.utils import logging
7
+
8
+ import torch.nn.functional as F
9
+
10
+ from fla.ops.simple_gla import chunk_simple_gla, fused_chunk_simple_gla
11
+ from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla
12
+ from .modeling_qwen3 import Qwen3RMSNorm
13
+ from .configuration_hybrid import HybridConfig
14
+ from .modeling_qwen3 import apply_rotary_pos_emb
15
+ from .cache import HybridCache
16
+ from fla.modules import ShortConvolution
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
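+ # ALiBi-style slope construction; the slopes are later negated and used as per-head
+ # log-decay rates for the lightning attention recurrence.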
+ def _build_slope_tensor(nheads: int):
23
+ def get_slopes(n):
24
+ def get_slopes_power_of_2(n):
25
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
26
+ ratio = start
27
+ return [start * ratio**i for i in range(n)]
28
+
29
+ if math.log2(n).is_integer():
30
+ return get_slopes_power_of_2(
31
+ n
32
+ ) # In the paper, we only train models that have 2^a heads for some a. This function has
33
+ else: # some good properties that only occur when the input is a power of 2. To maintain that even
34
+ closest_power_of_2 = 2 ** math.floor(
35
+ math.log2(n)
36
+ ) # when the number of heads is not a power of 2, we use this workaround.
37
+ return (
38
+ get_slopes_power_of_2(closest_power_of_2)
39
+ + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
40
+ )
41
+
42
+ slopes = torch.tensor(get_slopes(nheads)) # (nheads,)
43
+ return slopes
44
+
45
+
46
+ class LightningAttention(nn.Module):
47
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
48
+
49
+ def __init__(
50
+ self,
51
+ layer_idx: int,
52
+ hidden_size: int,
53
+ num_attention_heads: int,
54
+ num_key_value_heads: int,
55
+ head_dim: int,
56
+ attention_dropout: float = 0.0,
57
+ use_output_gate: bool = False,
58
+ use_short_conv: bool = False,
59
+ conv_size: int = 4,
60
+ attention_bias: bool = False,
61
+ rms_norm_eps: float = 1e-6,
62
+ use_rope: bool = False,
63
+ # attn_sqrtd: bool = True,
64
+ use_output_norm: bool = False,
65
+ qk_norm: bool = True,
66
+ rope_head_dim: Optional[int] = None,
67
+ # div_d: bool = False,
68
+ scale: str = '1/sqrt(d)',
69
+ ):
70
+ super().__init__()
71
+ self.layer_idx = layer_idx
72
+ self.hidden_size = hidden_size
73
+ self.num_attention_heads = num_attention_heads
74
+ self.num_key_value_heads = num_key_value_heads
75
+ self.num_key_value_groups = num_attention_heads // num_key_value_heads
76
+ self.head_dim = head_dim
77
+ if scale == '1/sqrt(d)':
78
+ self.scale = self.head_dim ** (-0.5)
79
+ elif scale == '1/d':
80
+ self.scale = self.head_dim ** (-1.0)
81
+ else:
82
+ self.scale = 1.0
83
+ self.attention_dropout = attention_dropout
84
+ self.is_causal = True
85
+ self.use_output_gate = use_output_gate
86
+ self.attention_bias = attention_bias
87
+ self.rms_norm_eps = rms_norm_eps
88
+ self.use_rope = use_rope
89
+ self.qk_norm = qk_norm
90
+ self.use_output_norm = use_output_norm
91
+ self.rope_head_dim = rope_head_dim if rope_head_dim is not None else head_dim
92
+ assert self.rope_head_dim <= self.head_dim
93
+ self.use_short_conv = use_short_conv
94
+ self.conv_size = conv_size
95
+
96
+ self.q_proj = nn.Linear(
97
+ self.hidden_size,
98
+ self.num_attention_heads * self.head_dim,
99
+ bias=self.attention_bias,
100
+ )
101
+ self.k_proj = nn.Linear(
102
+ self.hidden_size,
103
+ self.num_key_value_heads * self.head_dim,
104
+ bias=self.attention_bias,
105
+ )
106
+ self.v_proj = nn.Linear(
107
+ self.hidden_size,
108
+ self.num_key_value_heads * self.head_dim,
109
+ bias=self.attention_bias,
110
+ )
111
+ self.o_proj = nn.Linear(
112
+ self.num_attention_heads * self.head_dim,
113
+ self.hidden_size,
114
+ bias=self.attention_bias,
115
+ )
116
+ if self.use_output_norm:
117
+ self.o_norm = Qwen3RMSNorm(
118
+ hidden_size=self.num_attention_heads * self.head_dim,
119
+ eps=self.rms_norm_eps,
120
+ )
121
+
122
+ if self.use_output_gate:
123
+ self.z_proj = nn.Linear(
124
+ self.hidden_size,
125
+ self.num_attention_heads * self.head_dim,
126
+ bias=self.attention_bias,
127
+ )
128
+
129
+ if self.qk_norm:
130
+ self.q_norm = Qwen3RMSNorm(self.head_dim, eps=self.rms_norm_eps)
131
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=self.rms_norm_eps)
132
+
133
+ if self.use_short_conv:
134
+ self.conv_size = conv_size
135
+ self.q_conv1d = ShortConvolution(
136
+ hidden_size=self.num_attention_heads * self.head_dim,
137
+ kernel_size=conv_size,
138
+ activation='silu',
139
+ use_fast_conv1d=False,
140
+ )
141
+ self.k_conv1d = ShortConvolution(
142
+ hidden_size=self.num_key_value_heads * self.head_dim,
143
+ kernel_size=conv_size,
144
+ activation='silu',
145
+ use_fast_conv1d=False,
146
+ )
147
+ self.v_conv1d = ShortConvolution(
148
+ hidden_size=self.num_key_value_heads * self.head_dim,
149
+ kernel_size=conv_size,
150
+ activation='silu',
151
+ use_fast_conv1d=False,
152
+ )
153
+
154
+ def attn_fn(
155
+ self,
156
+ q: Tensor, # (b, t, h, d)
157
+ k: Tensor, # (b, t, h, d)
158
+ v: Tensor, # (b, t, h, d)
159
+ decay: Tensor, # (h,)
160
+ scale: float | None = None, # will use dk^(-1) if None.
161
+ initial_state: Tensor | None = None, # (b, h, dk, dv)
162
+ mode: str = 'chunk',
163
+ ) -> tuple[Tensor, Tensor]:
164
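+ # Simple gated-linear-attention recurrence, roughly: S_t = exp(decay) * S_{t-1} + k_t^T v_t
+ # and o_t = q_t S_t, with one scalar decay per head. Short sequences take the fused
+ # recurrent kernel, longer ones the chunked kernel.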
+ seqlen = q.shape[1]
165
+ mode = "fused_recurrent" if seqlen < 64 else "chunk"
166
+ if mode == "chunk":
167
+ o, final_state = fused_chunk_simple_gla(
168
+ q=q,
169
+ k=k,
170
+ v=v,
171
+ g_gamma=decay, # (h,)
172
+ initial_state=initial_state,
173
+ output_final_state=True,
174
+ scale=scale,
175
+ # head_first=False,
176
+ ) # (b, t, h, d)
177
+ elif mode == "fused_recurrent":
178
+ o, final_state = fused_recurrent_simple_gla(
179
+ q=q,
180
+ k=k,
181
+ v=v,
182
+ g_gamma=decay,
183
+ scale=scale,
184
+ initial_state=initial_state,
185
+ output_final_state=True,
186
+ # reverse=reverse,
187
+ # cu_seqlens=cu_seqlens,
188
+ # head_first=False,
189
+ )
190
+ else:
191
+ raise ValueError(f"Invalid mode: {mode}")
192
+ # else:
193
+ # print('recurrent')
194
+ # # Recurrent
195
+ # if S is None:
196
+ # b = k.shape[0]
197
+ # h = k.shape[1]
198
+ # dk = k.shape[3]
199
+ # dv = v.shape[3]
200
+ # S = torch.zeros(b, h, dk, dv, device=q.device, dtype=torch.float32)
201
+ # q = q.to(torch.float32)
202
+ # k = k.to(torch.float32)
203
+ # v = v.to(torch.float32)
204
+ # if self.attn_sqrtd:
205
+ # k = k * self.scaling
206
+ # ys = []
207
+ # s = torch.exp(s) # (h)
208
+ # for i in range(seqlen):
209
+ # qi = q[:, :, i, :]
210
+ # ki = k[:, :, i, :]
211
+ # vi = v[:, :, i, :]
212
+ # S = einsum(S, s, "b h dk dv, h -> b h dk dv")
213
+ # S = S + einsum(ki, vi, "b h dk, b h dv -> b h dk dv")
214
+ # yi = einsum(qi, S, "b h dk, b h dk dv -> b h dv")
215
+ # ys.append(yi)
216
+ # past_key_values.update(
217
+ # recurrent_state=S, layer_idx=self.layer_idx, offset=seqlen
218
+ # )
219
+ # o = torch.stack(ys, dim=2) # (b, h, t, d)
220
+ # # print('=' * 100)
221
+ # # print(o.shape)
222
+ # o = rearrange(o, "b h t d -> b t (h d)").contiguous()
223
+ # o = o.to(hidden_states.dtype) # (b, t, d)
224
+
225
+ return o, final_state
226
+
227
+ def forward(
228
+ self,
229
+ hidden_states: torch.Tensor,
230
+ position_ids: Optional[torch.LongTensor] = None,
231
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
232
+ attention_mask: Optional[torch.Tensor] = None,
233
+ past_key_values: Optional[HybridCache] = None,
234
+ use_cache: Optional[bool] = False,
235
+ # cache_position: Optional[torch.LongTensor] = None,
236
+ **kwargs,
237
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[HybridCache]]:
238
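+ # NOTE: the incoming padding mask is discarded here, so the linear-attention path
+ # below runs over padded positions as well.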
+ attention_mask = None
239
+ bsz, seqlen, _ = hidden_states.shape
240
+
241
+ last_state = None
242
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
243
+ last_state = past_key_values[self.layer_idx]
244
+
245
+ # print('============ Lightning attention input ============')
246
+ # print(hidden_states.shape)
247
+
248
+ q = self.q_proj(hidden_states)
249
+ k = self.k_proj(hidden_states)
250
+ v = self.v_proj(hidden_states)
251
+ if self.use_short_conv:
252
+ conv_state_q, conv_state_k, conv_state_v = None, None, None
253
+ if last_state is not None:
254
+ conv_state_q, conv_state_k, conv_state_v = last_state['conv_state']
255
+ conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
256
+ q, conv_state_q = self.q_conv1d(x=q,
257
+ mask=conv_mask,
258
+ cache=conv_state_q,
259
+ output_final_state=use_cache)
260
+ k, conv_state_k = self.k_conv1d(x=k,
261
+ mask=conv_mask,
262
+ cache=conv_state_k,
263
+ output_final_state=use_cache)
264
+ v, conv_state_v = self.v_conv1d(x=v,
265
+ mask=conv_mask,
266
+ cache=conv_state_v,
267
+ output_final_state=use_cache)
268
+
269
+ # print('============ Lightning attention after short conv ============')
270
+ # print(q.shape, k.shape, v.shape)
271
+
272
+ q = rearrange(q, "b t (h d) -> b t h d", d=self.head_dim)
273
+ k = rearrange(k, "b t (h d) -> b t h d", d=self.head_dim)
274
+ v = rearrange(v, "b t (h d) -> b t h d", d=self.head_dim)
275
+ # print('============ Lightning attention input after rearrange ============')
276
+ # print(q.shape, k.shape, v.shape)
277
+
278
+ if self.qk_norm:
279
+ q = self.q_norm(q)
280
+ k = self.k_norm(k)
281
+
282
+ if self.use_rope:
283
+ assert (
284
+ position_embeddings is not None
285
+ ), "position_embeddings is required when use_rope is True"
286
+ cos, sin = position_embeddings
287
+
288
+ # (B, T, H, D) -> (B, H, T, D)
289
+ # q, k = q.transpose(1, 2), k.transpose(1, 2)
290
+ q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=2)
291
+ # (B, H, T, D) -> (B, T, H, D)
292
+ # q, k = q.transpose(1, 2), k.transpose(1, 2)
293
+ # Rearrange QK to match RoPE's head dim
294
+ # rope_dim_not_match = q.shape[-1] != self.rope_head_dim
295
+ # if rope_dim_not_match:
296
+ # orig_nq = q.shape[1]
297
+ # orig_nk = k.shape[1]
298
+ # q = rearrange(q, "b h t (h2 d) -> b (h h2) t d", d=self.rope_head_dim)
299
+ # k = rearrange(k, "b h t (h2 d) -> b (h h2) t d", d=self.rope_head_dim)
300
+
301
+ # q, k = apply_rotary_pos_emb(q, k, cos, sin)
302
+
303
+ # if rope_dim_not_match:
304
+ # q = rearrange(q, "b (h h2) t d -> b h t (h2 d)", h=orig_nq)
305
+ # k = rearrange(k, "b (h h2) t d -> b h t (h2 d)", h=orig_nk)
306
+
307
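+ # The linear-attention kernels expect one K/V head per query head, so grouped
+ # K/V heads are repeated up to full MHA here.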
+ if self.num_key_value_heads < self.num_attention_heads:
308
+ group_size = self.num_attention_heads // self.num_key_value_heads
309
+ k = repeat(k, 'b t h d -> b t (h g) d', g=group_size) # (B, T, nh, dh)
310
+ v = repeat(v, 'b t h d -> b t (h g) d', g=group_size) # (B, T, nh, dh)
311
+
312
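+ # Per-head decay: negated ALiBi slopes, so heads with larger slopes forget older
+ # context faster.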
+ s = (
313
+ _build_slope_tensor(self.num_attention_heads).to(
314
+ k.device, dtype=torch.float32
315
+ )
316
+ * (-1.0)
317
+ ) # (h)
318
+
319
+ initial_state = None
320
+ if past_key_values is not None and len(past_key_values) > self.layer_idx:
321
+ layer_state = past_key_values[self.layer_idx]
322
+ initial_state = layer_state['recurrent_state']
323
+
324
+ # q = rearrange(q, "b h t d -> b t h d").to(torch.float32)
325
+ # k = rearrange(k, "b h t d -> b t h d").to(torch.float32)
326
+ # v = rearrange(v, "b h t d -> b t h d").to(torch.float32)
327
+ q = q.to(torch.float32)
328
+ k = k.to(torch.float32)
329
+ v = v.to(torch.float32)
330
+ s = s.to(torch.float32)
331
+
332
+ o, final_state = self.attn_fn(
333
+ q=q,
334
+ k=k,
335
+ v=v,
336
+ decay=s,
337
+ initial_state=initial_state,
338
+ scale=self.scale,
339
+ )
340
+
341
+ # print('============ Lightning attention output after attn_fn ============')
342
+ # print(o.shape)
343
+
344
+ if past_key_values is not None:
345
+ past_key_values.update(
346
+ recurrent_state=final_state,
347
+ conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
348
+ layer_idx=self.layer_idx,
349
+ offset=seqlen,
350
+ )
351
+
352
+ o = rearrange(o, "b t h d -> b t (h d)").contiguous().to(hidden_states.dtype) # (b, t, d)
353
+
354
+ # print('============ Lightning attention output after rearrange ============')
355
+ # print(f"output shape: {o.shape}")
356
+ if self.use_output_norm:
357
+ o = self.o_norm(o) # (b, t, d)
358
+
359
+ if self.use_output_gate:
360
+ z = F.sigmoid(self.z_proj(hidden_states)) # (b, t, d)
361
+ o = o * z # (b, t, d)
362
+
363
+ y = self.o_proj(o)
364
+ return y, None, past_key_values
365
+
366
+
367
+ def build_lightning_attn_with_attn(
368
+ attn_layer: nn.Module,
369
+ config: HybridConfig,
370
+ layer_idx: int,
371
+ ) -> nn.Module:
372
+
373
+ layer = LightningAttention(
374
+ layer_idx,
375
+ hidden_size=config.hidden_size,
376
+ num_attention_heads=config.lightning_nh,
377
+ num_key_value_heads=config.lightning_nkv,
378
+ head_dim=config.lightning_head_dim,
379
+ attention_dropout=config.attention_dropout,
380
+ use_output_gate=config.lightning_use_output_gate,
381
+ use_output_norm=config.lightning_use_output_norm,
382
+ attention_bias=config.attention_bias,
383
+ rms_norm_eps=config.rms_norm_eps,
384
+ use_rope=config.lightning_use_rope,
385
+ # attn_sqrtd=config.attn_sqrtd,
386
+ qk_norm=config.lightning_use_qk_norm,
387
+ rope_head_dim=config.head_dim,
388
+ scale=config.lightning_scale,
389
+ use_short_conv=config.lightning_use_short_conv,
390
+ conv_size=config.lightning_conv_size,
391
+ )
392
+
393
+ # print('============ Lightning attention layer ============')
394
+ # print(f"Layer idx: {layer_idx}")
395
+ # print(layer)
396
+ # print('==================================================')
397
+
398
+ if config.rand_init:
399
+ return layer
400
+
401
+ q_proj = attn_layer.q_proj
402
+ k_proj = attn_layer.k_proj
403
+ v_proj = attn_layer.v_proj
404
+ o_proj = attn_layer.o_proj
405
+
406
+ # (nh * head_dim, hidden_size)
407
+ wq = q_proj.weight.data.clone() # type: ignore
408
+ wk = k_proj.weight.data.clone() # type: ignore
409
+ wv = v_proj.weight.data.clone() # type: ignore
410
+ wo = o_proj.weight.data.clone() # type: ignore
411
+
412
+ if config.expand_kv_proj:
413
+ wk = wk.reshape(-1, config.head_dim, config.hidden_size)
414
+ wv = wv.reshape(-1, config.head_dim, config.hidden_size)
415
+ assert wk.shape[1] == wv.shape[1] and wk.shape[0] == config.num_key_value_heads
416
+
417
+ # Repeat KV projections to convert it to MHA
418
+ target_kv_size = config.lightning_nkv * config.lightning_head_dim
419
+ orig_kv_size = config.num_key_value_heads * config.head_dim
420
+ expand_size = target_kv_size // orig_kv_size
421
+ wk = wk.repeat_interleave(expand_size, dim=0)
422
+ wv = wv.repeat_interleave(expand_size, dim=0)
423
+
424
+ wk = wk.reshape(-1, config.hidden_size)
425
+ wv = wv.reshape(-1, config.hidden_size)
426
+
427
+ # print(layer)
428
+ # print(wq.shape)
429
+ # print(wk.shape)
430
+ # print(wv.shape)
431
+ # print(wo.shape)
432
+ # print(layer.q_proj.weight.shape)
433
+ # print(layer.k_proj.weight.shape)
434
+ # print(layer.v_proj.weight.shape)
435
+ # print(layer.o_proj.weight.shape)
436
+ # exit()
437
+
438
+ layer.q_proj.weight.data.copy_(wq)
439
+ layer.k_proj.weight.data.copy_(wk)
440
+ layer.v_proj.weight.data.copy_(wv)
441
+ layer.o_proj.weight.data.copy_(wo)
442
+
443
+ if hasattr(attn_layer, 'k_norm') and hasattr(layer, 'k_norm'):
444
+ k_norm_weights = attn_layer.k_norm.weight.data.clone()
445
+ layer.k_norm.weight.data.copy_(k_norm_weights)
446
+
447
+ if hasattr(attn_layer, 'q_norm') and hasattr(layer, 'q_norm'):
448
+ q_norm_weights = attn_layer.q_norm.weight.data.clone()
449
+ layer.q_norm.weight.data.copy_(q_norm_weights)
450
+
451
+ return layer
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a99b8a7e7db7353157c188c9f210460f8aa740a581730ff677057dbf88229d5
3
+ size 3852319072
modeling_hybrid.py ADDED
@@ -0,0 +1,609 @@
1
+ from typing import Optional, Tuple, Union, List, Dict, Any
2
+ from functools import partial
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+ from torch.utils.checkpoint import checkpoint
7
+
8
+ from transformers.activations import ACT2FN
9
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
10
+ from transformers.generation import GenerationMixin
11
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
12
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
13
+ from transformers.modeling_outputs import (
14
+ BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast,
16
+ )
17
+ from cut_cross_entropy import linear_cross_entropy
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.processing_utils import Unpack
20
+ from transformers.utils import auto_docstring, can_return_tuple, logging, is_torch_flex_attn_available
21
+ from .configuration_hybrid import HybridConfig
22
+ from .modeling_qwen3 import Qwen3RMSNorm, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding
23
+ from .gdn import GatedDeltaNet
24
+ # from .mamba2 import Mamba2Mixer
25
+ from .lightning_attn import LightningAttention
26
+ from .cache import HybridCache
27
+ # from .kda import KimiDeltaAttention
28
+
29
+ if is_torch_flex_attn_available():
30
+ from torch.nn.attention.flex_attention import BlockMask
31
+
32
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
33
+
34
+
35
+ logger = logging.get_logger(__name__)
36
+
37
+
38
+ class HybridDecoderLayer(nn.Module):
39
+ def __init__(self, config: HybridConfig, layer_idx: int):
40
+ super().__init__()
41
+ self.config = config
42
+ self.hidden_size = config.hidden_size
43
+ self.layer_idx = layer_idx
44
+ mixer_type = config.mixer_types[layer_idx]
45
+ self.mixer_type = mixer_type
46
+ if mixer_type == 'attn':
47
+ self.self_attn = Qwen3Attention(
48
+ config=config,
49
+ layer_idx=layer_idx,
50
+ )
51
+ elif mixer_type == 'mamba2':
52
+ self.self_attn = Mamba2Mixer(
53
+ layer_idx=layer_idx,
54
+ hidden_size=config.hidden_size,
55
+ num_heads=config.num_attention_heads,
56
+ n_groups=config.mamba2_n_groups,
57
+ expand_ratio=config.mamba2_expand_ratio,
58
+ conv_kernel=config.mamba2_conv_kernel,
59
+ state_size=config.head_dim,
60
+ head_dim=config.head_dim,
61
+ use_bias=config.mamba2_bias,
62
+ hidden_act=config.mamba2_hidden_act,
63
+ )
64
+ elif mixer_type == 'gdn':
65
+ self.self_attn = GatedDeltaNet(
66
+ layer_idx=layer_idx,
67
+ hidden_size=config.hidden_size,
68
+ expand_v=config.gdn_expand_v,
69
+ num_heads=config.gdn_nh,
70
+ num_kv_heads=config.gdn_nkv,
71
+ key_dim=config.head_dim,
72
+ val_dim=config.head_dim,
73
+ use_gate=config.gdn_use_gate,
74
+ use_short_conv=config.gdn_use_short_conv,
75
+ activation=config.gdn_activation,
76
+ qk_norm=config.gdn_use_qk_norm,
77
+ use_rope=config.gdn_use_rope,
78
+ )
79
+ elif mixer_type == 'gla':
80
+ raise NotImplementedError("GatedLightningAttention is not implemented")
81
+ # self.self_attn = GatedLinearAttention(config=config, layer_idx=layer_idx)
82
+ elif mixer_type in ['lightning-attn', 'lightning_attn']:
83
+ # raise NotImplementedError("LightningAttention is not implemented")
84
+ self.self_attn = LightningAttention(
85
+ layer_idx=layer_idx,
86
+ hidden_size=config.hidden_size,
87
+ num_attention_heads=config.lightning_nh,
88
+ num_key_value_heads=config.lightning_nkv,
89
+ head_dim=config.lightning_head_dim,
90
+ attention_dropout=config.attention_dropout,
91
+ use_output_gate=config.lightning_use_output_gate,
92
+ attention_bias=config.attention_bias,
93
+ rms_norm_eps=config.rms_norm_eps,
94
+ use_rope=config.lightning_use_rope,
95
+ use_output_norm=config.lightning_use_output_norm,
96
+ qk_norm=config.lightning_use_qk_norm,
97
+ scale=config.lightning_scale,
98
+ use_short_conv=config.lightning_use_short_conv,
99
+ conv_size=config.lightning_conv_size,
100
+ )
101
+ elif mixer_type == 'kda':
102
+ self.self_attn = KimiDeltaAttention(config=config, layer_idx=layer_idx)
103
+ elif mixer_type == 'rwkv7':
104
+ raise NotImplementedError("RWKV7Attention is not implemented")
105
+ # self.self_attn = RWKV7Attention(config=config, layer_idx=layer_idx)
106
+ else:
107
+ raise ValueError(f"Invalid mixer type: {mixer_type}")
108
+ self.mlp = Qwen3MLP(config)
109
+ self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
110
+ self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
111
+ if (
112
+ config.sliding_window and config._attn_implementation != "flash_attention_2"
113
+ ): # diff with Llama is this warning
114
+ logger.warning_once(
115
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
116
+ "unexpected results may be encountered."
117
+ )
118
+
119
+ def forward(
120
+ self,
121
+ hidden_states: torch.Tensor,
122
+ attention_mask: Optional[torch.Tensor] = None,
123
+ position_ids: Optional[torch.LongTensor] = None,
124
+ past_key_values: Optional[Cache] = None,
125
+ output_attentions: Optional[bool] = False,
126
+ use_cache: Optional[bool] = False,
127
+ cache_position: Optional[torch.LongTensor] = None,
128
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
129
+ **kwargs: Unpack[FlashAttentionKwargs],
130
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor, Cache]]]:
131
+
132
+ # ==== Time mixing ====
133
+ residual = hidden_states
134
+ hidden_states = self.input_layernorm(hidden_states)
135
+
136
+ # Position embeddings, depends on mixer type and config
137
+ if self.mixer_type == "attn" and not self.config.attn_use_rope:
138
+ position_embeddings = None
139
+ elif self.mixer_type == "lightning-attn" and not self.config.lightning_use_rope:
140
+ position_embeddings = None
141
+ elif self.mixer_type == "kda" and not self.config.kda_use_rope:
142
+ position_embeddings = None
143
+ elif self.mixer_type == "gdn" and not self.config.gdn_use_rope:
144
+ position_embeddings = None
145
+
146
+ # TODO: Also handle other kinds of token mixers
147
+ hidden_states, self_attn_weights, past_key_values = self.self_attn(
148
+ hidden_states=hidden_states,
149
+ attention_mask=attention_mask,
150
+ position_ids=position_ids,
151
+ past_key_values=past_key_values,
152
+ output_attentions=output_attentions,
153
+ use_cache=use_cache,
154
+ cache_position=cache_position,
155
+ position_embeddings=position_embeddings,
156
+ **kwargs,
157
+ )
158
+ hidden_states = residual + hidden_states
159
+
160
+ # ==== Channel mixing ====
161
+ residual = hidden_states
162
+ hidden_states = self.post_attention_layernorm(hidden_states)
163
+ hidden_states = self.mlp(hidden_states)
164
+ hidden_states = residual + hidden_states
165
+
166
+ outputs = (hidden_states, self_attn_weights, past_key_values)
167
+
168
+ return outputs
169
+
170
+
171
+ # @auto_docstring
172
+ class HybridPreTrainedModel(PreTrainedModel):
173
+ config_class = HybridConfig
174
+ base_model_prefix = "model"
175
+ supports_gradient_checkpointing = True
176
+ _no_split_modules = ["HybridDecoderLayer"]
177
+ _skip_keys_device_placement = ["past_key_values"]
178
+ _supports_flash_attn_2 = True
179
+ _supports_sdpa = True
180
+ _supports_flex_attn = True
181
+ _supports_cache_class = True
182
+ _supports_quantized_cache = True
183
+ _supports_static_cache = True
184
+ _supports_attention_backend = True
185
+
186
+ def _init_weights(self, module: nn.Module):
187
+ std = self.config.initializer_range
188
+ if isinstance(module, nn.Linear):
189
+ module.weight.data.normal_(mean=0.0, std=std)
190
+ if module.bias is not None:
191
+ module.bias.data.zero_()
192
+ elif isinstance(module, nn.Embedding):
193
+ module.weight.data.normal_(mean=0.0, std=std)
194
+ if module.padding_idx is not None:
195
+ module.weight.data[module.padding_idx].zero_()
196
+ elif isinstance(module, Qwen3RMSNorm):
197
+ module.weight.data.fill_(1.0)
198
+
199
+
200
+ # @auto_docstring
201
+ class HybridModel(HybridPreTrainedModel):
202
+ def __init__(self, config: HybridConfig):
203
+ super().__init__(config)
204
+ self.padding_idx = config.pad_token_id
205
+ self.vocab_size = config.vocab_size
206
+
207
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
208
+ self.layers = nn.ModuleList(
209
+ [HybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
210
+ )
211
+ self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
212
+ self.rotary_emb = Qwen3RotaryEmbedding(config=config)
213
+ self.gradient_checkpointing = False
214
+
215
+ # Initialize weights and apply final processing
216
+ self.post_init()
217
+
218
+ def get_input_embeddings(self):
219
+ return self.embed_tokens
220
+
221
+ def set_input_embeddings(self, value):
222
+ self.embed_tokens = value
223
+
224
+ @can_return_tuple
225
+ @auto_docstring
226
+ def forward(
227
+ self,
228
+ input_ids: Optional[torch.LongTensor] = None,
229
+ attention_mask: Optional[torch.Tensor] = None,
230
+ position_ids: Optional[torch.LongTensor] = None,
231
+ past_key_values: Optional[Cache] = None,
232
+ inputs_embeds: Optional[torch.FloatTensor] = None,
233
+ use_cache: Optional[bool] = None,
234
+ output_attentions: Optional[bool] = None,
235
+ output_hidden_states: Optional[bool] = None,
236
+ cache_position: Optional[torch.LongTensor] = None,
237
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
238
+ ) -> BaseModelOutputWithPast:
239
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
240
+ output_hidden_states = (
241
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
242
+ )
243
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
244
+
245
+ if (input_ids is None) ^ (inputs_embeds is not None):
246
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
247
+
248
+ if self.gradient_checkpointing and self.training and use_cache:
249
+ logger.warning_once(
250
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
251
+ )
252
+ use_cache = False
253
+
254
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
255
+ if not isinstance(past_key_values, (type(None), Cache)):
256
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
257
+
258
+ if inputs_embeds is None:
259
+ inputs_embeds = self.embed_tokens(input_ids)
260
+
261
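+ # `generate` may hand in a DynamicCache; swap it for a HybridCache, which keeps
+ # attention KV, recurrent, and conv states side by side for each layer.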
+ if use_cache:
262
+ if past_key_values is None or isinstance(past_key_values, DynamicCache):
263
+ past_key_values = HybridCache()
264
+
265
+ if cache_position is None:
266
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
267
+ cache_position = torch.arange(
268
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
269
+ )
270
+
271
+ if position_ids is None:
272
+ position_ids = cache_position.unsqueeze(0)
273
+
274
+ causal_mask = self._update_causal_mask(
275
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
276
+ )
277
+
278
+ hidden_states = inputs_embeds
279
+
280
+ # create position embeddings to be shared across the decoder layers
281
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
282
+
283
+ # decoder layers
284
+ all_hidden_states = () if output_hidden_states else None
285
+ all_self_attns = () if output_attentions else None
286
+
287
+ for decoder_layer in self.layers:
288
+ if output_hidden_states:
289
+ all_hidden_states += (hidden_states,)
290
+
291
+ if self.gradient_checkpointing and self.training:
292
+ layer_fwd = partial(
293
+ checkpoint,
294
+ decoder_layer,
295
+ use_reentrant=False,
296
+ )
297
+ else:
298
+ layer_fwd = decoder_layer
299
+
300
+ layer_outputs = layer_fwd(
301
+ hidden_states,
302
+ attention_mask=causal_mask,
303
+ position_ids=position_ids,
304
+ past_key_values=past_key_values,
305
+ output_attentions=output_attentions,
306
+ use_cache=use_cache,
307
+ cache_position=cache_position,
308
+ position_embeddings=position_embeddings,
309
+ **flash_attn_kwargs,
310
+ )
311
+
312
+ hidden_states = layer_outputs[0]
313
+
314
+ if output_attentions:
315
+ all_self_attns += (layer_outputs[1],)
316
+
317
+ hidden_states = self.norm(hidden_states)
318
+
319
+ # add hidden states from the last decoder layer
320
+ if output_hidden_states:
321
+ all_hidden_states += (hidden_states,)
322
+
323
+ return BaseModelOutputWithPast(
324
+ last_hidden_state=hidden_states,
325
+ past_key_values=past_key_values if use_cache else None,
326
+ hidden_states=all_hidden_states,
327
+ attentions=all_self_attns,
328
+ )
329
+
330
+ def _update_causal_mask(
331
+ self,
332
+ attention_mask: Union[torch.Tensor, "BlockMask"],
333
+ input_tensor: torch.Tensor,
334
+ cache_position: torch.Tensor,
335
+ past_key_values: Cache,
336
+ output_attentions: bool = False,
337
+ ):
338
+ if self.config._attn_implementation == "flash_attention_2":
339
+ if attention_mask is not None and past_key_values is not None:
340
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
341
+ if is_padding_right:
342
+ raise ValueError(
343
+ "You are attempting to perform batched generation with padding_side='right'"
344
+ " this may lead to unexpected behaviour for Flash Attention version of Hybrid. Make sure to "
345
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
346
+ )
347
+ if attention_mask is not None and 0.0 in attention_mask:
348
+ return attention_mask
349
+ return None
350
+ if self.config._attn_implementation == "flex_attention":
351
+ if isinstance(attention_mask, torch.Tensor):
352
+ attention_mask = make_flex_block_causal_mask(attention_mask)
353
+ return attention_mask
354
+
355
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
356
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
357
+ # to infer the attention mask.
358
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
359
+ using_static_cache = isinstance(past_key_values, StaticCache)
360
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
361
+
362
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
363
+ if (
364
+ self.config._attn_implementation == "sdpa"
365
+ and not (using_static_cache or using_sliding_window_cache)
366
+ and not output_attentions
367
+ ):
368
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
369
+ attention_mask,
370
+ inputs_embeds=input_tensor,
371
+ past_key_values_length=past_seen_tokens,
372
+ sliding_window=self.config.sliding_window,
373
+ is_training=self.training,
374
+ ):
375
+ return None
376
+
377
+ dtype = input_tensor.dtype
378
+ min_dtype = torch.finfo(dtype).min
379
+ sequence_length = input_tensor.shape[1]
380
+ # SlidingWindowCache or StaticCache
381
+ if using_sliding_window_cache or using_static_cache:
382
+ target_length = past_key_values.get_max_cache_shape()
383
+ # DynamicCache or no cache
384
+ else:
385
+ target_length = (
386
+ attention_mask.shape[-1]
387
+ if isinstance(attention_mask, torch.Tensor)
388
+ else past_seen_tokens + sequence_length + 1
389
+ )
390
+
391
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
392
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
393
+ attention_mask,
394
+ sequence_length=sequence_length,
395
+ target_length=target_length,
396
+ dtype=dtype,
397
+ cache_position=cache_position,
398
+ batch_size=input_tensor.shape[0],
399
+ config=self.config,
400
+ past_key_values=past_key_values,
401
+ )
402
+
403
+ if (
404
+ self.config._attn_implementation == "sdpa"
405
+ and attention_mask is not None
406
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
407
+ and not output_attentions
408
+ ):
409
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
410
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
411
+ # Details: https://github.com/pytorch/pytorch/issues/110213
412
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
413
+
414
+ return causal_mask
415
+
416
+ @staticmethod
417
+ def _prepare_4d_causal_attention_mask_with_cache_position(
418
+ attention_mask: torch.Tensor,
419
+ sequence_length: int,
420
+ target_length: int,
421
+ dtype: torch.dtype,
422
+ cache_position: torch.Tensor,
423
+ batch_size: int,
424
+ config: HybridConfig,
425
+ past_key_values: Cache,
426
+ ):
427
+ """
428
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
429
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
430
+
431
+ Args:
432
+ attention_mask (`torch.Tensor`):
433
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
434
+ sequence_length (`int`):
435
+ The sequence length being processed.
436
+ target_length (`int`):
437
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
438
+ dtype (`torch.dtype`):
439
+ The dtype to use for the 4D attention mask.
440
+ cache_position (`torch.Tensor`):
441
+ Indices depicting the position of the input sequence tokens in the sequence.
442
+ batch_size (`torch.Tensor`):
443
+ Batch size.
444
+ config (`HybridConfig`):
445
+ The model's configuration class
446
+ past_key_values (`Cache`):
447
+ The cache class that is being used currently to generate
448
+ """
449
+ if attention_mask is not None and attention_mask.dim() == 4:
450
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
451
+ causal_mask = attention_mask
452
+ else:
453
+ min_dtype = torch.finfo(dtype).min
454
+ causal_mask = torch.full(
455
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
456
+ )
457
+ diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
458
+ -1, 1
459
+ )
460
+ text_config = config.get_text_config()
461
+ if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
462
+ # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
463
+ # the check is needed to verify is current checkpoint was trained with sliding window or not
464
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
465
+ sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
466
+ cache_position.reshape(-1, 1) - text_config.sliding_window
467
+ )
468
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
469
+ causal_mask *= diagonal_attend_mask
470
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
471
+ if attention_mask is not None:
472
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
473
+ if attention_mask.shape[-1] > target_length:
474
+ attention_mask = attention_mask[:, :target_length]
475
+ mask_length = attention_mask.shape[-1]
476
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
477
+ causal_mask.device
478
+ )
479
+ padding_mask = padding_mask == 0
480
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
481
+ padding_mask, min_dtype
482
+ )
483
+ return causal_mask
484
+
485
+
486
+ class KwargsForCausalLM(FlashAttentionKwargs): ...
487
+
488
+
489
+ # @auto_docstring
490
+ class HybridForCausalLM(HybridPreTrainedModel, GenerationMixin):
491
+ _tied_weights_keys = ["lm_head.weight"]
492
+ _tp_plan = {"lm_head": "colwise_rep"}
493
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
494
+
495
+ def __init__(self, config: HybridConfig):
496
+ super().__init__(config)
497
+ self.model = HybridModel(config)
498
+ self.vocab_size = config.vocab_size
499
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
500
+
501
+ self.use_cce = True
502
+ # Initialize weights and apply final processing
503
+ self.post_init()
504
+
505
+ def get_input_embeddings(self):
506
+ return self.model.embed_tokens
507
+
508
+ def set_input_embeddings(self, value):
509
+ self.model.embed_tokens = value
510
+
511
+ def get_output_embeddings(self):
512
+ return self.lm_head
513
+
514
+ def set_output_embeddings(self, new_embeddings):
515
+ self.lm_head = new_embeddings
516
+
517
+ def set_decoder(self, decoder):
518
+ self.model = decoder
519
+
520
+ def get_decoder(self):
521
+ return self.model
522
+
523
+ # @can_return_tuple
524
+ # @auto_docstring
525
+ def forward(
526
+ self,
527
+ input_ids: Optional[torch.LongTensor] = None,
528
+ attention_mask: Optional[torch.Tensor] = None,
529
+ position_ids: Optional[torch.LongTensor] = None,
530
+ past_key_values: Optional[Cache] = None,
531
+ inputs_embeds: Optional[torch.FloatTensor] = None,
532
+ labels: Optional[torch.LongTensor] = None,
533
+ use_cache: Optional[bool] = None,
534
+ output_attentions: Optional[bool] = None,
535
+ output_hidden_states: Optional[bool] = None,
536
+ cache_position: Optional[torch.LongTensor] = None,
537
+ logits_to_keep: Union[int, torch.Tensor] = 0,
538
+ return_logits: bool = False,
539
+ **kwargs: Unpack[KwargsForCausalLM],
540
+ ) -> CausalLMOutputWithPast:
541
+ r"""
542
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
543
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
544
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
545
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
546
+
547
+ Example:
548
+
549
+ ```python
550
+ >>> from transformers import AutoTokenizer, HybridForCausalLM
551
+
552
+ >>> model = HybridForCausalLM.from_pretrained("Qwen/Hybrid-8B")
553
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Hybrid-8B")
554
+
555
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
556
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
557
+
558
+ >>> # Generate
559
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
560
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
561
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
562
+ ```"""
563
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
564
+ output_hidden_states = (
565
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
566
+ )
567
+
568
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
569
+ outputs: BaseModelOutputWithPast = self.model(
570
+ input_ids=input_ids,
571
+ attention_mask=attention_mask,
572
+ position_ids=position_ids,
573
+ past_key_values=past_key_values,
574
+ inputs_embeds=inputs_embeds,
575
+ use_cache=use_cache,
576
+ output_attentions=output_attentions,
577
+ output_hidden_states=output_hidden_states,
578
+ cache_position=cache_position,
579
+ **kwargs,
580
+ )
581
+
582
+ hidden_states: Tensor = outputs.last_hidden_state
583
+ loss = None
584
+ logits = None
585
+ if return_logits or not self.training:
586
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
587
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
588
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
589
+
590
+ if labels is not None:
591
+ labels = labels.to(hidden_states.device)
592
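+ # Cut cross-entropy: fuses the lm_head matmul into the loss so the full
+ # (batch, seq, vocab) logits tensor is never materialized (`shift=True` handles the
+ # next-token label shift).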
+ if self.use_cce:
593
+ loss = linear_cross_entropy(
594
+ hidden_states,
595
+ self.lm_head.weight,
596
+ labels,
597
+ shift=True,
598
+ )
599
+ else:
600
+ logits = self.lm_head(hidden_states).to(torch.float32)
601
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
602
+
603
+ return CausalLMOutputWithPast(
604
+ loss=loss,
605
+ logits=logits,
606
+ past_key_values=outputs.past_key_values,
607
+ hidden_states=outputs.hidden_states,
608
+ attentions=outputs.attentions,
609
+ )
modeling_qwen3.py ADDED
@@ -0,0 +1,1045 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/qwen3/modular_qwen3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_qwen3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+
22
+ from typing import Callable, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from torch import nn
26
+ import torch.nn.functional as F
27
+ from einops import einsum
28
+
29
+ from transformers.activations import ACT2FN
30
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
31
+ from transformers.generation import GenerationMixin
32
+ from transformers.integrations import use_kernel_forward_from_hub
33
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
34
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
35
+ from transformers.modeling_layers import GradientCheckpointingLayer
36
+ from transformers.modeling_outputs import (
37
+ BaseModelOutputWithPast,
38
+ CausalLMOutputWithPast,
39
+ QuestionAnsweringModelOutput,
40
+ SequenceClassifierOutputWithPast,
41
+ TokenClassifierOutput,
42
+ )
43
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
44
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
+ from transformers.processing_utils import Unpack
46
+ from transformers.utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
47
+ # from .configuration_qwen3 import
48
+ from .configuration_hybrid import HybridConfig
49
+ from .cache import HybridCache
50
+
51
+
52
+ if is_torch_flex_attn_available():
53
+ from torch.nn.attention.flex_attention import BlockMask
54
+
55
+ from transformers.integrations.flex_attention import make_flex_block_causal_mask
56
+
57
+
58
+ logger = logging.get_logger(__name__)
59
+
60
+
61
+ @use_kernel_forward_from_hub("RMSNorm")
62
+ class Qwen3RMSNorm(nn.Module):
63
+ def __init__(self, hidden_size, eps=1e-6):
64
+ """
65
+ Qwen3RMSNorm is equivalent to T5LayerNorm
66
+ """
67
+ super().__init__()
68
+ self.weight = nn.Parameter(torch.ones(hidden_size))
69
+ self.variance_epsilon = eps
70
+
71
+ def forward(self, hidden_states):
72
+ input_dtype = hidden_states.dtype
73
+ hidden_states = hidden_states.to(torch.float32)
74
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
75
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
76
+ return self.weight * hidden_states.to(input_dtype)
77
+
78
+ def extra_repr(self):
79
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
80
+
81
+
82
+ class Qwen3MLP(nn.Module):
83
+ def __init__(self, config):
84
+ super().__init__()
85
+ self.config = config
86
+ self.hidden_size = config.hidden_size
87
+ self.intermediate_size = config.intermediate_size
88
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
89
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
90
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
91
+ self.act_fn = ACT2FN[config.hidden_act]
92
+
93
+ def forward(self, x):
94
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
95
+ return down_proj
96
+
97
+
98
+ def rotate_half(x):
99
+ """Rotates half the hidden dims of the input."""
100
+ x1 = x[..., : x.shape[-1] // 2]
101
+ x2 = x[..., x.shape[-1] // 2 :]
102
+ return torch.cat((-x2, x1), dim=-1)
103
+
104
+
105
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
106
+ """Applies Rotary Position Embedding to the query and key tensors.
107
+
108
+ Args:
109
+ q (`torch.Tensor`): The query tensor, assume (B, H, T, D) by default.
110
+ k (`torch.Tensor`): The key tensor, assume (B, H, T, D) by default.
111
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
112
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
113
+ position_ids (`torch.Tensor`, *optional*):
114
+ Deprecated and unused.
115
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
116
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
117
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
118
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
119
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
120
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
121
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
122
+ Returns:
123
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
124
+ """
125
+ cos = cos.unsqueeze(unsqueeze_dim)
126
+ sin = sin.unsqueeze(unsqueeze_dim)
127
+ q_embed = (q * cos) + (rotate_half(q) * sin)
128
+ k_embed = (k * cos) + (rotate_half(k) * sin)
129
+ return q_embed, k_embed
130
+
131
+
132
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
133
+ """
134
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
135
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
136
+ """
137
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
138
+ if n_rep == 1:
139
+ return hidden_states
140
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
141
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
142
+
143
+
144
+ def eager_attention_forward(
145
+ module: "Qwen3Attention",
146
+ query: torch.Tensor,
147
+ key: torch.Tensor,
148
+ value: torch.Tensor,
149
+ attention_mask: Optional[torch.Tensor],
150
+ scaling: float,
151
+ dropout: float = 0.0,
152
+ **kwargs,
153
+ ):
154
+ key_states = repeat_kv(key, module.num_key_value_groups)
155
+ value_states = repeat_kv(value, module.num_key_value_groups)
156
+
157
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
158
+ if attention_mask is not None:
159
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
160
+ attn_weights = attn_weights + causal_mask
161
+
162
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
163
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
164
+ attn_output = torch.matmul(attn_weights, value_states)
165
+ attn_output = attn_output.transpose(1, 2).contiguous()
166
+
167
+ return attn_output, attn_weights
168
+
169
+
170
+ class Qwen3Attention(nn.Module):
171
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
172
+
173
+ def __init__(self, config: HybridConfig, layer_idx: int):
174
+ super().__init__()
175
+ self.config = config
176
+ self.layer_idx = layer_idx
177
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
178
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
179
+ self.scaling = self.head_dim**-0.5
180
+ self.attention_dropout = config.attention_dropout
181
+ self.is_causal = True
182
+ self.use_output_gate = config.attn_use_output_gate
183
+
184
+ self.q_proj = nn.Linear(
185
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
186
+ )
187
+ self.k_proj = nn.Linear(
188
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
189
+ )
190
+ self.v_proj = nn.Linear(
191
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
192
+ )
193
+ self.o_proj = nn.Linear(
194
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
195
+ )
196
+ self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # unlike olmo, only on the head dim!
197
+ self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape
198
+ self.sliding_window = config.sliding_window
199
+ if not (
200
+ self.config.use_sliding_window
201
+ and getattr(self.config, "sliding_window", None) is not None
202
+ and self.layer_idx >= self.config.max_window_layers
203
+ ):
204
+ self.sliding_window = None
205
+
206
+ if self.use_output_gate:
207
+ self.o_gate = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
208
+
209
+ def forward(
210
+ self,
211
+ hidden_states: torch.Tensor,
212
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
213
+ attention_mask: Optional[torch.Tensor],
214
+ past_key_values: Optional[HybridCache] = None,
215
+ cache_position: Optional[torch.LongTensor] = None,
216
+ position_ids: Optional[torch.LongTensor] = None,
217
+ **kwargs: Unpack[FlashAttentionKwargs],
218
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[HybridCache]]:
219
+ input_shape = hidden_states.shape[:-1]
220
+ hidden_shape = (*input_shape, -1, self.head_dim)
221
+
222
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
223
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
224
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
225
+
226
+ if position_embeddings is not None:
227
+ cos, sin = position_embeddings
228
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
229
+ else:
230
+ cos, sin = None, None
231
+
232
+ if past_key_values is not None:
233
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
234
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
235
+ q_len = key_states.shape[-2]
236
+ attn_state = (key_states, value_states)
237
+ state = past_key_values.update(
238
+ attn_state=attn_state,
239
+ layer_idx=self.layer_idx,
240
+ offset=q_len,
241
+ )
242
+ key_states, value_states = state['attn_state']
243
+
244
+ attention_interface: Callable = eager_attention_forward
245
+ if self.config._attn_implementation != "eager":
246
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
247
+ logger.warning_once(
248
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
249
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
250
+ )
251
+ else:
252
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
253
+
254
+ # Logits scaling for length extrapolation
255
+ if self.config.attn_logits_scaling is not None:
256
+ query_states = query_states.transpose(1, 2) # (B, T, H, D)
257
+ if isinstance(self.config.attn_logits_scaling, float):
258
+ scale = self.config.attn_logits_scaling
259
+ query_states = query_states * scale
260
+ query_states = query_states.to(torch.bfloat16)
261
+ elif isinstance(self.config.attn_logits_scaling, str):
262
+ assert position_ids is not None, 'position_ids is required for attn_logits_scaling'
263
+ if len(self.config.attn_logits_scaling.split()) > 1:
264
+ a = float(self.config.attn_logits_scaling.split()[1])
265
+ else:
266
+ a = 362.0
267
+ # Create (B, T) tensor
268
+ scale = torch.log(position_ids + a) / torch.full_like(position_ids, fill_value=a).log() # (B, T)
269
+ query_states = einsum(query_states, scale, 'b t h d, b t -> b t h d')
270
+ query_states = query_states.to(torch.bfloat16)
271
+ else:
272
+ raise TypeError(f"Unsupported type for attn_logits_scaling: {type(self.config.attn_logits_scaling)}")
273
+
274
+ query_states = query_states.transpose(1, 2) # (B, H, T, D)
275
+
276
+ o, attn_weights = attention_interface(
277
+ self,
278
+ query_states,
279
+ key_states,
280
+ value_states,
281
+ attention_mask,
282
+ dropout=0.0 if not self.training else self.attention_dropout,
283
+ scaling=self.scaling,
284
+ sliding_window=self.sliding_window, # diff with Llama
285
+ **kwargs,
286
+ )
287
+
288
+ o = o.reshape(*input_shape, -1).contiguous()
289
+ if self.use_output_gate:
290
+ o = o * F.sigmoid(self.o_gate(hidden_states))
291
+ y = self.o_proj(o)
292
+ return y, o, past_key_values
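The `attn_logits_scaling` branch above implements log-length query scaling: each query at position p is multiplied by log(p + a) / log(a), with a = 362 unless a second value is given in the config string, so early positions stay roughly unscaled and long-context positions get a mild boost. A quick illustration of the magnitudes involved (values are approximate, not from the repo):

```python
import math
import torch

a = 362.0                                   # default used in the branch above
position_ids = torch.tensor([0, 2048, 4096, 6144])
scale = torch.log(position_ids + a) / math.log(a)
print(scale.tolist())                       # roughly [1.00, 1.32, 1.43, 1.49]
```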
293
+
294
+
295
+ class Qwen3DecoderLayer(GradientCheckpointingLayer):
296
+ def __init__(self, config, layer_idx: int):
297
+ super().__init__()
298
+ self.hidden_size = config.hidden_size
299
+ self.self_attn = Qwen3Attention(config=config, layer_idx=layer_idx)
300
+ self.mlp = Qwen3MLP(config)
301
+ self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
302
+ self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
303
+ if (
304
+ config.sliding_window and config._attn_implementation != "flash_attention_2"
305
+ ): # diff with Llama is this warning
306
+ logger.warning_once(
307
+ f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
308
+ "unexpected results may be encountered."
309
+ )
310
+
311
+ def forward(
312
+ self,
313
+ hidden_states: torch.Tensor,
314
+ attention_mask: Optional[torch.Tensor] = None,
315
+ position_ids: Optional[torch.LongTensor] = None,
316
+ past_key_values: Optional[Cache] = None,
317
+ output_attentions: Optional[bool] = False,
318
+ use_cache: Optional[bool] = False,
319
+ cache_position: Optional[torch.LongTensor] = None,
320
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
321
+ **kwargs: Unpack[FlashAttentionKwargs],
322
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
323
+ residual = hidden_states
324
+ hidden_states = self.input_layernorm(hidden_states)
325
+
326
+ # Self Attention
327
+ hidden_states, self_attn_weights, _ = self.self_attn(
328
+ hidden_states=hidden_states,
329
+ attention_mask=attention_mask,
330
+ position_ids=position_ids,
331
+ past_key_values=past_key_values,
332
+ output_attentions=output_attentions,
333
+ use_cache=use_cache,
334
+ cache_position=cache_position,
335
+ position_embeddings=position_embeddings,
336
+ **kwargs,
337
+ )
338
+ hidden_states = residual + hidden_states
339
+
340
+ # Fully Connected
341
+ residual = hidden_states
342
+ hidden_states = self.post_attention_layernorm(hidden_states)
343
+ hidden_states = self.mlp(hidden_states)
344
+ hidden_states = residual + hidden_states
345
+
346
+ outputs = (hidden_states,)
347
+ if output_attentions:
348
+ outputs += (self_attn_weights,)
349
+
350
+ return outputs
351
+
352
+
353
+ @auto_docstring
354
+ class Qwen3PreTrainedModel(PreTrainedModel):
355
+ config_class = HybridConfig
356
+ base_model_prefix = "model"
357
+ supports_gradient_checkpointing = True
358
+ _no_split_modules = ["Qwen3DecoderLayer"]
359
+ _skip_keys_device_placement = ["past_key_values"]
360
+ _supports_flash_attn_2 = True
361
+ _supports_sdpa = True
362
+ _supports_flex_attn = True
363
+ _supports_cache_class = True
364
+ _supports_quantized_cache = True
365
+ _supports_static_cache = True
366
+ _supports_attention_backend = True
367
+
368
+ def _init_weights(self, module):
369
+ std = self.config.initializer_range
370
+ if isinstance(module, nn.Linear):
371
+ module.weight.data.normal_(mean=0.0, std=std)
372
+ if module.bias is not None:
373
+ module.bias.data.zero_()
374
+ elif isinstance(module, nn.Embedding):
375
+ module.weight.data.normal_(mean=0.0, std=std)
376
+ if module.padding_idx is not None:
377
+ module.weight.data[module.padding_idx].zero_()
378
+ elif isinstance(module, Qwen3RMSNorm):
379
+ module.weight.data.fill_(1.0)
380
+
381
+
382
+ class Qwen3RotaryEmbedding(nn.Module):
383
+ def __init__(self, config, device=None):
384
+ super().__init__()
385
+ # BC: "rope_type" was originally "type"
386
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
387
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
388
+ else:
389
+ self.rope_type = "default"
390
+ self.max_seq_len_cached = config.max_position_embeddings
391
+ self.original_max_seq_len = config.max_position_embeddings
392
+
393
+ self.config = config
394
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
395
+
396
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
397
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
398
+ self.original_inv_freq = self.inv_freq
399
+
400
+ @torch.no_grad()
401
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
402
+ def forward(self, x, position_ids):
403
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
404
+ position_ids_expanded = position_ids[:, None, :].float()
405
+
406
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
407
+ with torch.autocast(device_type=device_type, enabled=False): # Force float32
408
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
409
+ emb = torch.cat((freqs, freqs), dim=-1)
410
+ cos = emb.cos() * self.attention_scaling
411
+ sin = emb.sin() * self.attention_scaling
412
+
413
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
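`Qwen3RotaryEmbedding` returns the `cos`/`sin` tables that `apply_rotary_pos_emb` (top of this file) combines with a `rotate_half` helper. That helper is defined earlier in the file, outside this excerpt; assuming the usual Hugging Face form, it is simply:

```python
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotates half the hidden dims of the input (standard RoPE helper)."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)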
414
+
415
+
416
+ @auto_docstring
417
+ class Qwen3Model(Qwen3PreTrainedModel):
418
+ def __init__(self, config):
419
+ super().__init__(config)
420
+ self.padding_idx = config.pad_token_id
421
+ self.vocab_size = config.vocab_size
422
+
423
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
424
+ self.layers = nn.ModuleList(
425
+ [Qwen3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
426
+ )
427
+ self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
428
+ self.rotary_emb = Qwen3RotaryEmbedding(config=config)
429
+ self.gradient_checkpointing = False
430
+
431
+ # Initialize weights and apply final processing
432
+ self.post_init()
433
+
434
+ def get_input_embeddings(self):
435
+ return self.embed_tokens
436
+
437
+ def set_input_embeddings(self, value):
438
+ self.embed_tokens = value
439
+
440
+ @can_return_tuple
441
+ @auto_docstring
442
+ def forward(
443
+ self,
444
+ input_ids: Optional[torch.LongTensor] = None,
445
+ attention_mask: Optional[torch.Tensor] = None,
446
+ position_ids: Optional[torch.LongTensor] = None,
447
+ past_key_values: Optional[Cache] = None,
448
+ inputs_embeds: Optional[torch.FloatTensor] = None,
449
+ use_cache: Optional[bool] = None,
450
+ output_attentions: Optional[bool] = None,
451
+ output_hidden_states: Optional[bool] = None,
452
+ cache_position: Optional[torch.LongTensor] = None,
453
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
454
+ ) -> BaseModelOutputWithPast:
455
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
456
+ output_hidden_states = (
457
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
458
+ )
459
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
460
+
461
+ if (input_ids is None) ^ (inputs_embeds is not None):
462
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
463
+
464
+ if self.gradient_checkpointing and self.training and use_cache:
465
+ logger.warning_once(
466
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
467
+ )
468
+ use_cache = False
469
+
470
+ # TODO (joao): remove this exception in v4.56 -- it exists for users that try to pass a legacy cache
471
+ if not isinstance(past_key_values, (type(None), Cache)):
472
+ raise ValueError("The `past_key_values` should be either a `Cache` object or `None`.")
473
+
474
+ if inputs_embeds is None:
475
+ inputs_embeds = self.embed_tokens(input_ids)
476
+
477
+ if use_cache and past_key_values is None:
478
+ past_key_values = DynamicCache()
479
+
480
+ if cache_position is None:
481
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
482
+ cache_position = torch.arange(
483
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
484
+ )
485
+
486
+ if position_ids is None:
487
+ position_ids = cache_position.unsqueeze(0)
488
+
489
+ causal_mask = self._update_causal_mask(
490
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
491
+ )
492
+
493
+ hidden_states = inputs_embeds
494
+
495
+ # create position embeddings to be shared across the decoder layers
496
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
497
+
498
+ # decoder layers
499
+ all_hidden_states = () if output_hidden_states else None
500
+ all_self_attns = () if output_attentions else None
501
+
502
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
503
+ if output_hidden_states:
504
+ all_hidden_states += (hidden_states,)
505
+
506
+ layer_outputs = decoder_layer(
507
+ hidden_states,
508
+ attention_mask=causal_mask,
509
+ position_ids=position_ids,
510
+ past_key_values=past_key_values,
511
+ output_attentions=output_attentions,
512
+ use_cache=use_cache,
513
+ cache_position=cache_position,
514
+ position_embeddings=position_embeddings,
515
+ **flash_attn_kwargs,
516
+ )
517
+
518
+ hidden_states = layer_outputs[0]
519
+
520
+ if output_attentions:
521
+ all_self_attns += (layer_outputs[1],)
522
+
523
+ hidden_states = self.norm(hidden_states)
524
+
525
+ # add hidden states from the last decoder layer
526
+ if output_hidden_states:
527
+ all_hidden_states += (hidden_states,)
528
+
529
+ return BaseModelOutputWithPast(
530
+ last_hidden_state=hidden_states,
531
+ past_key_values=past_key_values if use_cache else None,
532
+ hidden_states=all_hidden_states,
533
+ attentions=all_self_attns,
534
+ )
535
+
536
+ def _update_causal_mask(
537
+ self,
538
+ attention_mask: Union[torch.Tensor, "BlockMask"],
539
+ input_tensor: torch.Tensor,
540
+ cache_position: torch.Tensor,
541
+ past_key_values: Cache,
542
+ output_attentions: bool = False,
543
+ ):
544
+ if self.config._attn_implementation == "flash_attention_2":
545
+ if attention_mask is not None and past_key_values is not None:
546
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
547
+ if is_padding_right:
548
+ raise ValueError(
549
+ "You are attempting to perform batched generation with padding_side='right'"
550
+ " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to "
551
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
552
+ )
553
+ if attention_mask is not None and 0.0 in attention_mask:
554
+ return attention_mask
555
+ return None
556
+ if self.config._attn_implementation == "flex_attention":
557
+ if isinstance(attention_mask, torch.Tensor):
558
+ attention_mask = make_flex_block_causal_mask(attention_mask)
559
+ return attention_mask
560
+
561
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
562
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
563
+ # to infer the attention mask.
564
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
565
+ using_static_cache = isinstance(past_key_values, StaticCache)
566
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
567
+
568
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
569
+ if (
570
+ self.config._attn_implementation == "sdpa"
571
+ and not (using_static_cache or using_sliding_window_cache)
572
+ and not output_attentions
573
+ ):
574
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
575
+ attention_mask,
576
+ inputs_embeds=input_tensor,
577
+ past_key_values_length=past_seen_tokens,
578
+ sliding_window=self.config.sliding_window,
579
+ is_training=self.training,
580
+ ):
581
+ return None
582
+
583
+ dtype = input_tensor.dtype
584
+ min_dtype = torch.finfo(dtype).min
585
+ sequence_length = input_tensor.shape[1]
586
+ # SlidingWindowCache or StaticCache
587
+ if using_sliding_window_cache or using_static_cache:
588
+ target_length = past_key_values.get_max_cache_shape()
589
+ # DynamicCache or no cache
590
+ else:
591
+ target_length = (
592
+ attention_mask.shape[-1]
593
+ if isinstance(attention_mask, torch.Tensor)
594
+ else past_seen_tokens + sequence_length + 1
595
+ )
596
+
597
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
598
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
599
+ attention_mask,
600
+ sequence_length=sequence_length,
601
+ target_length=target_length,
602
+ dtype=dtype,
603
+ cache_position=cache_position,
604
+ batch_size=input_tensor.shape[0],
605
+ config=self.config,
606
+ past_key_values=past_key_values,
607
+ )
608
+
609
+ if (
610
+ self.config._attn_implementation == "sdpa"
611
+ and attention_mask is not None
612
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
613
+ and not output_attentions
614
+ ):
615
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
616
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
617
+ # Details: https://github.com/pytorch/pytorch/issues/110213
618
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
619
+
620
+ return causal_mask
621
+
622
+ @staticmethod
623
+ def _prepare_4d_causal_attention_mask_with_cache_position(
624
+ attention_mask: torch.Tensor,
625
+ sequence_length: int,
626
+ target_length: int,
627
+ dtype: torch.dtype,
628
+ cache_position: torch.Tensor,
629
+ batch_size: int,
630
+ config,
631
+ past_key_values: Cache,
632
+ ):
633
+ """
634
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
635
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
636
+
637
+ Args:
638
+ attention_mask (`torch.Tensor`):
639
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
640
+ sequence_length (`int`):
641
+ The sequence length being processed.
642
+ target_length (`int`):
643
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
644
+ dtype (`torch.dtype`):
645
+ The dtype to use for the 4D attention mask.
646
+ cache_position (`torch.Tensor`):
647
+ Indices depicting the position of the input sequence tokens in the sequence.
648
+ batch_size (`int`):
649
+ Batch size.
650
+ config (`Qwen3Config`):
651
+ The model's configuration class
652
+ past_key_values (`Cache`):
653
+ The cache class that is being used currently to generate
654
+ """
655
+ if attention_mask is not None and attention_mask.dim() == 4:
656
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
657
+ causal_mask = attention_mask
658
+ else:
659
+ min_dtype = torch.finfo(dtype).min
660
+ causal_mask = torch.full(
661
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
662
+ )
663
+ diagonal_attend_mask = torch.arange(target_length, device=cache_position.device) > cache_position.reshape(
664
+ -1, 1
665
+ )
666
+ text_config = config.get_text_config()
667
+ if getattr(text_config, "use_sliding_window", True) and text_config.sliding_window is not None:
668
+ # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
669
+ # the check is needed to verify whether the current checkpoint was trained with sliding window or not
670
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
671
+ sliding_attend_mask = torch.arange(target_length, device=cache_position.device) <= (
672
+ cache_position.reshape(-1, 1) - text_config.sliding_window
673
+ )
674
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
675
+ causal_mask *= diagonal_attend_mask
676
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
677
+ if attention_mask is not None:
678
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
679
+ if attention_mask.shape[-1] > target_length:
680
+ attention_mask = attention_mask[:, :target_length]
681
+ mask_length = attention_mask.shape[-1]
682
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
683
+ causal_mask.device
684
+ )
685
+ padding_mask = padding_mask == 0
686
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
687
+ padding_mask, min_dtype
688
+ )
689
+ return causal_mask
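To make the shape and fill conventions above concrete, here is a toy rendering of the additive causal mask for a 3-token prefill with no padding (illustrative only; the real method also handles sliding windows and 2D padding masks):

```python
import torch

seq_len = target_len = 3
min_dtype = torch.finfo(torch.float32).min
cache_position = torch.arange(seq_len)

causal_mask = torch.full((seq_len, target_len), fill_value=min_dtype)
causal_mask *= torch.arange(target_len) > cache_position.reshape(-1, 1)  # keep the fill only above the diagonal
print(causal_mask)
# row i: 0.0 for columns <= i (attend), finfo.min for future columns (blocked);
# the method then broadcasts this to (batch_size, 1, query_length, key_value_length).
```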
690
+
691
+
692
+ class KwargsForCausalLM(FlashAttentionKwargs): ...
693
+
694
+
695
+ @auto_docstring
696
+ class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
697
+ _tied_weights_keys = ["lm_head.weight"]
698
+ _tp_plan = {"lm_head": "colwise_rep"}
699
+ _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
700
+
701
+ def __init__(self, config):
702
+ super().__init__(config)
703
+ self.model = Qwen3Model(config)
704
+ self.vocab_size = config.vocab_size
705
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
706
+
707
+ # Initialize weights and apply final processing
708
+ self.post_init()
709
+
710
+ def get_input_embeddings(self):
711
+ return self.model.embed_tokens
712
+
713
+ def set_input_embeddings(self, value):
714
+ self.model.embed_tokens = value
715
+
716
+ def get_output_embeddings(self):
717
+ return self.lm_head
718
+
719
+ def set_output_embeddings(self, new_embeddings):
720
+ self.lm_head = new_embeddings
721
+
722
+ def set_decoder(self, decoder):
723
+ self.model = decoder
724
+
725
+ def get_decoder(self):
726
+ return self.model
727
+
728
+ @can_return_tuple
729
+ @auto_docstring
730
+ def forward(
731
+ self,
732
+ input_ids: Optional[torch.LongTensor] = None,
733
+ attention_mask: Optional[torch.Tensor] = None,
734
+ position_ids: Optional[torch.LongTensor] = None,
735
+ past_key_values: Optional[Cache] = None,
736
+ inputs_embeds: Optional[torch.FloatTensor] = None,
737
+ labels: Optional[torch.LongTensor] = None,
738
+ use_cache: Optional[bool] = None,
739
+ output_attentions: Optional[bool] = None,
740
+ output_hidden_states: Optional[bool] = None,
741
+ cache_position: Optional[torch.LongTensor] = None,
742
+ logits_to_keep: Union[int, torch.Tensor] = 0,
743
+ **kwargs: Unpack[KwargsForCausalLM],
744
+ ) -> CausalLMOutputWithPast:
745
+ r"""
746
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
747
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
748
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
749
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
750
+
751
+ Example:
752
+
753
+ ```python
754
+ >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
755
+
756
+ >>> model = Qwen3ForCausalLM.from_pretrained("Qwen/Qwen3-8B")
757
+ >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
758
+
759
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
760
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
761
+
762
+ >>> # Generate
763
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
764
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
765
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
766
+ ```"""
767
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
768
+ output_hidden_states = (
769
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
770
+ )
771
+
772
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
773
+ outputs: BaseModelOutputWithPast = self.model(
774
+ input_ids=input_ids,
775
+ attention_mask=attention_mask,
776
+ position_ids=position_ids,
777
+ past_key_values=past_key_values,
778
+ inputs_embeds=inputs_embeds,
779
+ use_cache=use_cache,
780
+ output_attentions=output_attentions,
781
+ output_hidden_states=output_hidden_states,
782
+ cache_position=cache_position,
783
+ **kwargs,
784
+ )
785
+
786
+ hidden_states = outputs.last_hidden_state
787
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
788
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
789
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
790
+
791
+ loss = None
792
+ if labels is not None:
793
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
794
+
795
+ return CausalLMOutputWithPast(
796
+ loss=loss,
797
+ logits=logits,
798
+ past_key_values=outputs.past_key_values,
799
+ hidden_states=outputs.hidden_states,
800
+ attentions=outputs.attentions,
801
+ )
802
+
803
+
804
+ @auto_docstring(
805
+ custom_intro="""
806
+ The Qwen3 Model transformer with a sequence classification head on top (linear layer).
807
+
808
+ [`Qwen3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
809
+ (e.g. GPT-2) do.
810
+
811
+ Since it does classification on the last token, it needs to know the position of the last token. If a
812
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
813
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
814
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
815
+ each row of the batch).
816
+ """
817
+ )
818
+ class Qwen3ForSequenceClassification(Qwen3PreTrainedModel):
819
+ def __init__(self, config):
820
+ super().__init__(config)
821
+ self.num_labels = config.num_labels
822
+ self.model = Qwen3Model(config)
823
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
824
+
825
+ # Initialize weights and apply final processing
826
+ self.post_init()
827
+
828
+ def get_input_embeddings(self):
829
+ return self.model.embed_tokens
830
+
831
+ def set_input_embeddings(self, value):
832
+ self.model.embed_tokens = value
833
+
834
+ @can_return_tuple
835
+ @auto_docstring
836
+ def forward(
837
+ self,
838
+ input_ids: Optional[torch.LongTensor] = None,
839
+ attention_mask: Optional[torch.Tensor] = None,
840
+ position_ids: Optional[torch.LongTensor] = None,
841
+ past_key_values: Optional[Cache] = None,
842
+ inputs_embeds: Optional[torch.FloatTensor] = None,
843
+ labels: Optional[torch.LongTensor] = None,
844
+ use_cache: Optional[bool] = None,
845
+ output_attentions: Optional[bool] = None,
846
+ output_hidden_states: Optional[bool] = None,
847
+ ) -> SequenceClassifierOutputWithPast:
848
+ r"""
849
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
850
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
851
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
852
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
853
+ """
854
+
855
+ transformer_outputs: BaseModelOutputWithPast = self.model(
856
+ input_ids,
857
+ attention_mask=attention_mask,
858
+ position_ids=position_ids,
859
+ past_key_values=past_key_values,
860
+ inputs_embeds=inputs_embeds,
861
+ use_cache=use_cache,
862
+ output_attentions=output_attentions,
863
+ output_hidden_states=output_hidden_states,
864
+ )
865
+ hidden_states = transformer_outputs.last_hidden_state
866
+ logits = self.score(hidden_states)
867
+
868
+ if input_ids is not None:
869
+ batch_size = input_ids.shape[0]
870
+ else:
871
+ batch_size = inputs_embeds.shape[0]
872
+
873
+ if self.config.pad_token_id is None and batch_size != 1:
874
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
875
+ if self.config.pad_token_id is None:
876
+ last_non_pad_token = -1
877
+ elif input_ids is not None:
878
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
879
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
880
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
881
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
882
+ else:
883
+ last_non_pad_token = -1
884
+ logger.warning_once(
885
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
886
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
887
+ )
888
+
889
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
890
+
891
+ loss = None
892
+ if labels is not None:
893
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
894
+
895
+ return SequenceClassifierOutputWithPast(
896
+ loss=loss,
897
+ logits=pooled_logits,
898
+ past_key_values=transformer_outputs.past_key_values,
899
+ hidden_states=transformer_outputs.hidden_states,
900
+ attentions=transformer_outputs.attentions,
901
+ )
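The pooling above picks the rightmost non-pad token per row, which works for both left- and right-padded batches. A toy trace (the pad id of 0 is made up for the example):

```python
import torch

pad_token_id = 0                                   # hypothetical pad id for this illustration
input_ids = torch.tensor([[5, 7, 9, 0, 0],         # right-padded row
                          [0, 0, 3, 4, 6]])        # left-padded row
non_pad_mask = (input_ids != pad_token_id).int()
token_indices = torch.arange(input_ids.shape[-1])
last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
print(last_non_pad_token)                          # tensor([2, 4])
```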
902
+
903
+
904
+ @auto_docstring
905
+ class Qwen3ForTokenClassification(Qwen3PreTrainedModel):
906
+ def __init__(self, config):
907
+ super().__init__(config)
908
+ self.num_labels = config.num_labels
909
+ self.model = Qwen3Model(config)
910
+ if getattr(config, "classifier_dropout", None) is not None:
911
+ classifier_dropout = config.classifier_dropout
912
+ elif getattr(config, "hidden_dropout", None) is not None:
913
+ classifier_dropout = config.hidden_dropout
914
+ else:
915
+ classifier_dropout = 0.1
916
+ self.dropout = nn.Dropout(classifier_dropout)
917
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
918
+
919
+ # Initialize weights and apply final processing
920
+ self.post_init()
921
+
922
+ def get_input_embeddings(self):
923
+ return self.model.embed_tokens
924
+
925
+ def set_input_embeddings(self, value):
926
+ self.model.embed_tokens = value
927
+
928
+ @can_return_tuple
929
+ @auto_docstring
930
+ def forward(
931
+ self,
932
+ input_ids: Optional[torch.LongTensor] = None,
933
+ attention_mask: Optional[torch.Tensor] = None,
934
+ position_ids: Optional[torch.LongTensor] = None,
935
+ past_key_values: Optional[Cache] = None,
936
+ inputs_embeds: Optional[torch.FloatTensor] = None,
937
+ labels: Optional[torch.LongTensor] = None,
938
+ use_cache: Optional[bool] = None,
939
+ output_attentions: Optional[bool] = None,
940
+ output_hidden_states: Optional[bool] = None,
941
+ ) -> TokenClassifierOutput:
942
+ r"""
943
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
944
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
945
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
946
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
947
+ """
948
+
949
+ outputs: BaseModelOutputWithPast = self.model(
950
+ input_ids,
951
+ attention_mask=attention_mask,
952
+ position_ids=position_ids,
953
+ past_key_values=past_key_values,
954
+ inputs_embeds=inputs_embeds,
955
+ use_cache=use_cache,
956
+ output_attentions=output_attentions,
957
+ output_hidden_states=output_hidden_states,
958
+ )
959
+ sequence_output = outputs.last_hidden_state
960
+ sequence_output = self.dropout(sequence_output)
961
+ logits = self.score(sequence_output)
962
+
963
+ loss = None
964
+ if labels is not None:
965
+ loss = self.loss_function(logits, labels, self.config)
966
+
967
+ return TokenClassifierOutput(
968
+ loss=loss,
969
+ logits=logits,
970
+ hidden_states=outputs.hidden_states,
971
+ attentions=outputs.attentions,
972
+ )
973
+
974
+
975
+ @auto_docstring
976
+ class Qwen3ForQuestionAnswering(Qwen3PreTrainedModel):
977
+ base_model_prefix = "transformer"
978
+
979
+ def __init__(self, config):
980
+ super().__init__(config)
981
+ self.transformer = Qwen3Model(config)
982
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
983
+
984
+ # Initialize weights and apply final processing
985
+ self.post_init()
986
+
987
+ def get_input_embeddings(self):
988
+ return self.transformer.embed_tokens
989
+
990
+ def set_input_embeddings(self, value):
991
+ self.transformer.embed_tokens = value
992
+
993
+ @can_return_tuple
994
+ @auto_docstring
995
+ def forward(
996
+ self,
997
+ input_ids: Optional[torch.LongTensor] = None,
998
+ attention_mask: Optional[torch.Tensor] = None,
999
+ position_ids: Optional[torch.LongTensor] = None,
1000
+ past_key_values: Optional[Cache] = None,
1001
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1002
+ start_positions: Optional[torch.LongTensor] = None,
1003
+ end_positions: Optional[torch.LongTensor] = None,
1004
+ output_attentions: Optional[bool] = None,
1005
+ output_hidden_states: Optional[bool] = None,
1006
+ **kwargs,
1007
+ ) -> QuestionAnsweringModelOutput:
1008
+ outputs: BaseModelOutputWithPast = self.transformer(
1009
+ input_ids,
1010
+ attention_mask=attention_mask,
1011
+ position_ids=position_ids,
1012
+ past_key_values=past_key_values,
1013
+ inputs_embeds=inputs_embeds,
1014
+ output_attentions=output_attentions,
1015
+ output_hidden_states=output_hidden_states,
1016
+ )
1017
+
1018
+ sequence_output = outputs.last_hidden_state
1019
+
1020
+ logits = self.qa_outputs(sequence_output)
1021
+ start_logits, end_logits = logits.split(1, dim=-1)
1022
+ start_logits = start_logits.squeeze(-1).contiguous()
1023
+ end_logits = end_logits.squeeze(-1).contiguous()
1024
+
1025
+ loss = None
1026
+ if start_positions is not None and end_positions is not None:
1027
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1028
+
1029
+ return QuestionAnsweringModelOutput(
1030
+ loss=loss,
1031
+ start_logits=start_logits,
1032
+ end_logits=end_logits,
1033
+ hidden_states=outputs.hidden_states,
1034
+ attentions=outputs.attentions,
1035
+ )
1036
+
1037
+
1038
+ __all__ = [
1039
+ "Qwen3ForCausalLM",
1040
+ "Qwen3ForQuestionAnswering",
1041
+ "Qwen3Model",
1042
+ "Qwen3PreTrainedModel",
1043
+ "Qwen3ForSequenceClassification",
1044
+ "Qwen3ForTokenClassification",
1045
+ ]
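Since this modeling code ships inside the checkpoint repository rather than in `transformers` itself, loading it presumably requires `trust_remote_code=True`. A sketch (the repo id below is a placeholder, not the actual repository name):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-hybrid-qwen3-checkpoint"   # placeholder repo id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hey, are you conscious?", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```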
special_tokens_map.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "additional_special_tokens": [
3
+ "<|im_start|>",
4
+ "<|im_end|>",
5
+ "<|object_ref_start|>",
6
+ "<|object_ref_end|>",
7
+ "<|box_start|>",
8
+ "<|box_end|>",
9
+ "<|quad_start|>",
10
+ "<|quad_end|>",
11
+ "<|vision_start|>",
12
+ "<|vision_end|>",
13
+ "<|vision_pad|>",
14
+ "<|image_pad|>",
15
+ "<|video_pad|>"
16
+ ],
17
+ "eos_token": {
18
+ "content": "<|im_end|>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ },
24
+ "pad_token": {
25
+ "content": "<|endoftext|>",
26
+ "lstrip": false,
27
+ "normalized": false,
28
+ "rstrip": false,
29
+ "single_word": false
30
+ }
31
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
3
+ size 11422654
tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "clean_up_tokenization_spaces": false,
231
+ "eos_token": "<|im_end|>",
232
+ "errors": "replace",
233
+ "extra_special_tokens": {},
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
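The tokenizer configuration above pins `<|im_end|>` (id 151645) as EOS and `<|endoftext|>` (id 151643) as padding, with a 131072-token `model_max_length`. A quick way to confirm after loading (placeholder repo id again):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("your-org/your-hybrid-qwen3-checkpoint")  # placeholder repo id
print(tok.eos_token, tok.eos_token_id)   # <|im_end|> 151645
print(tok.pad_token, tok.pad_token_id)   # <|endoftext|> 151643
print(tok.model_max_length)              # 131072
```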
vocab.json ADDED
The diff for this file is too large to render. See raw diff