Crystalcareai committed
Commit 6dcf37d · verified · 1 Parent(s): 1f8e662

Update modeling_quiet.py

Files changed (1):
  1. modeling_quiet.py (+19, -31)
modeling_quiet.py CHANGED
@@ -270,22 +270,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 
 
 class QuietAttention(nn.Module):
-    """
-    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
-    and "Generating Long Sequences with Sparse Transformers".
-    """
-
     def __init__(self, config: QuietConfig, layer_idx: Optional[int] = None):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         if layer_idx is None:
             logger.warning_once(
-                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
-                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
-                "when creating this class."
+                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` when creating this class."
             )
-
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -296,20 +288,17 @@ class QuietAttention(nn.Module):
         self.is_causal = True
         self.attention_dropout = config.attention_dropout
 
-        if (self.head_dim * self.num_heads) != self.hidden_size:
+        if self.head_dim * self.num_heads != self.hidden_size:
             raise ValueError(
-                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
-                f" and `num_heads`: {self.num_heads})."
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`: {self.num_heads})."
             )
+
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
-
         self.rotary_emb = QuietRotaryEmbedding(
-            self.head_dim,
-            max_position_embeddings=self.max_position_embeddings,
-            base=self.rope_theta,
+            self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta
         )
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -327,8 +316,9 @@ class QuietAttention(nn.Module):
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         if "padding_mask" in kwargs:
             warnings.warn(
-                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
+                "`padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
             )
+
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states)
@@ -343,19 +333,17 @@ class QuietAttention(nn.Module):
         if past_key_value is not None:
             if self.layer_idx is None:
                 raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} for auto-regressive decoding with k/v caching, please make sure to initialize the attention class with a layer index."
                 )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
         if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
+            cache_kwargs = {"sin": sin, "cos": cos}  # required by original DynamicCache.update() function
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        # repeat k/v heads if n_kv_heads < n_heads
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
 
@@ -363,32 +351,32 @@
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
             raise ValueError(
-                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-                f" {attn_weights.size()}"
+                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is {attn_weights.size()}"
             )
 
         if attention_mask is not None:
+            if attention_mask.dim() == 3:
+                attention_mask = attention_mask.unsqueeze(1)
+            elif attention_mask.dim() == 2:
+                attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                 raise ValueError(
                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                 )
-
             attn_weights = attn_weights + attention_mask
 
-        # upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+
        attn_output = torch.matmul(attn_weights, value_states)
 
         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
             raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is {attn_output.size()}"
            )
 
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
+        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
 
         if not output_attentions:
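
The substantive change in the final hunk is the new attention-mask handling: a 3-D mask of shape (bsz, q_len, kv_seq_len) is unsqueezed to (bsz, 1, q_len, kv_seq_len), and a 2-D mask of shape (bsz, kv_seq_len) to (bsz, 1, 1, kv_seq_len), before the 4-D size check and the additive masking. The sketch below reproduces that logic on toy tensors; it is a minimal standalone example, and the helper name expand_mask_to_4d and the toy shapes are illustrative only and do not appear in modeling_quiet.py.

import torch


def expand_mask_to_4d(attention_mask: torch.Tensor) -> torch.Tensor:
    # 3-D (bsz, q_len, kv_seq_len) -> (bsz, 1, q_len, kv_seq_len): add a broadcast head dim.
    if attention_mask.dim() == 3:
        attention_mask = attention_mask.unsqueeze(1)
    # 2-D (bsz, kv_seq_len) -> (bsz, 1, 1, kv_seq_len); note this shape only satisfies the
    # later (bsz, 1, q_len, kv_seq_len) size check when q_len == 1 (single-token decoding).
    elif attention_mask.dim() == 2:
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    return attention_mask


bsz, num_heads, q_len, kv_seq_len = 2, 4, 5, 5
attn_weights = torch.randn(bsz, num_heads, q_len, kv_seq_len)

# Additive mask: 0.0 where attention is allowed, a large negative value where it is blocked.
mask_3d = torch.zeros(bsz, q_len, kv_seq_len)
mask_3d[:, :, -1] = torch.finfo(attn_weights.dtype).min

mask_4d = expand_mask_to_4d(mask_3d)
assert mask_4d.size() == (bsz, 1, q_len, kv_seq_len)

# Same sequence as the forward pass: add the mask to the logits, then softmax in fp32.
attn_weights = attn_weights + mask_4d
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)

The other functional edit, collapsing the two-line transpose/reshape into attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size), is behaviour-preserving: view requires a contiguous tensor, which the explicit contiguous() call already guarantees, so it produces the same result as the previous reshape.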