Update modeling_quiet.py
modeling_quiet.py  CHANGED  (+19 -31)
@@ -270,22 +270,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:


 class QuietAttention(nn.Module):
-    """
-    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
-    and "Generating Long Sequences with Sparse Transformers".
-    """
-
     def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
-                "when creating this class."
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class."
             )
-
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
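Note (not part of the diff): the warning kept above exists because the k/v cache is indexed per layer, so each attention module should be built with its own index. A minimal instantiation sketch; it assumes QuietConfig exposes the usual Mistral-style `num_hidden_layers` field, and the names are illustrative only:

    # Hypothetical sketch: give every attention layer its index so later cache
    # calls such as past_key_value.update(..., self.layer_idx, ...) know where to write.
    config = QuietConfig()  # assumed to carry the model's usual default sizes
    layers = [QuietAttention(config, layer_idx=i) for i in range(config.num_hidden_layers)]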
@@ -296,20 +288,17 @@ class QuietAttention(nn.Module):
         self.is_causal = True
         self.attention_dropout = config.attention_dropout

-        if (self.head_dim * self.num_heads) != self.hidden_size:
+        if self.head_dim * self.num_heads != self.hidden_size:
             raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})."
             )
+
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
         self.rotary_emb = QuietRotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
+            self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta
         )

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
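Note (not part of the diff): the q/k/v projections above implement grouped-query attention, so the key/value projections are smaller than the query projection whenever num_key_value_heads < num_heads. A runnable shape check; the field names come from the code above, while the concrete sizes are only example values:

    import torch
    import torch.nn as nn

    hidden_size, num_heads, num_key_value_heads = 4096, 32, 8
    head_dim = hidden_size // num_heads  # 128; __init__ rejects configs where this does not divide evenly

    q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
    k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=False)

    x = torch.randn(2, 10, hidden_size)   # (bsz, q_len, hidden_size)
    print(q_proj(x).shape)                # torch.Size([2, 10, 4096])
    print(k_proj(x).shape)                # torch.Size([2, 10, 1024]) -- 8 k/v heads instead of 32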
@@ -327,8 +316,9 @@ class QuietAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+                "`padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
+
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states)
@@ -343,19 +333,17 @@ class QuietAttention(nn.Module):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to initialize the attention class with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # required by original DynamicCache.update() function
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)

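Note (not part of the diff): `repeat_kv`, named in the first hunk header, is what makes the grouped-query k/v tensors compatible with the per-query-head matmul that follows. A sketch matching the standard transformers implementation of this helper, shown here for reference:

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # (bsz, num_key_value_heads, seq_len, head_dim) -> (bsz, num_key_value_heads * n_rep, seq_len, head_dim)
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

    k = torch.randn(2, 8, 10, 128)   # 8 key/value heads
    print(repeat_kv(k, 4).shape)     # torch.Size([2, 32, 10, 128]) -- matches 32 query heads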
@@ -363,32 +351,32 @@ class QuietAttention(nn.Module):

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is {attn_weights.size()}"
             )

         if attention_mask is not None:
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            elif attention_mask.dim() == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
             attn_weights = attn_weights + attention_mask

-        # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
         attn_output = torch.matmul(attn_weights, value_states)

         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}"
             )

-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
+        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)

         if not output_attentions:
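Note (not part of the diff): the new dim() handling above promotes 2-D and 3-D masks to the 4-D (bsz, 1, q_len, kv_seq_len) layout that broadcasts across heads when added to attn_weights; a 2-D (bsz, kv_seq_len) mask becomes (bsz, 1, 1, kv_seq_len), which only satisfies the exact size check when q_len == 1 (single-token decoding). A runnable illustration with made-up sizes:

    import torch

    bsz, num_heads, q_len, kv_seq_len = 2, 32, 1, 10
    attn_weights = torch.randn(bsz, num_heads, q_len, kv_seq_len)

    # Additive padding mask: 0.0 for visible positions, -inf for masked ones.
    mask_2d = torch.zeros(bsz, kv_seq_len)
    mask_2d[:, -2:] = float("-inf")

    mask_4d = mask_2d.unsqueeze(1).unsqueeze(2)   # (bsz, 1, 1, kv_seq_len)
    print(mask_4d.shape)                          # torch.Size([2, 1, 1, 10])
    print((attn_weights + mask_4d).shape)         # broadcasts to torch.Size([2, 32, 1, 10])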