cchudant committed
Commit d5ff350
Parent: 88b92ae

Fix the kv-cache dimensions

Hello!
I have noticed that the dimensions of the kv-cache here are odd, and do not match the Hugging Face transformers modeling_bloom.py file.
Is the departure from the Bloom dimensions intended?
Judging from the copy-pasted comments, it looks like a bug; moreover, `_convert_to_rw_cache` and its `_convert_to_standard_cache` counterpart match the Bloom dimensions.
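
To make the mismatch concrete, here is a minimal sketch assuming the Bloom-style cache layout documented in the file's own comments (key: `[batch_size * num_heads, head_dim, kv_length]`); the tensor sizes below are invented for illustration only:

```python
import torch

# Hypothetical sizes, for illustration only.
bh, head_dim = 4, 64    # bh = batch_size * num_heads
past_len, q_len = 7, 1  # cached length, new token(s)

# Bloom-style key cache layout: [bh, head_dim, kv_length]
past_key = torch.randn(bh, head_dim, past_len)
key_layer = torch.randn(bh, head_dim, q_len)

# dim=2 is the kv_length axis under this layout, so the cache extends correctly:
key = torch.cat((past_key, key_layer), dim=2)
print(key.shape)  # torch.Size([4, 64, 8])

# The old dim=1 concatenates along head_dim instead: it raises a size
# mismatch whenever past_len != q_len, and would silently double head_dim
# when the two lengths happen to be equal.
try:
    torch.cat((past_key, key_layer), dim=1)
except RuntimeError as err:
    print(f"dim=1 fails: {err}")
```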

Files changed (1): modelling_RW.py (+1, -1)
modelling_RW.py CHANGED
@@ -271,7 +271,7 @@ class Attention(nn.Module):
             # concatenate along seq_length dimension:
             # - key: [batch_size * self.num_heads, head_dim, kv_length]
             # - value: [batch_size * self.num_heads, kv_length, head_dim]
-            key_layer = torch.cat((past_key, key_layer), dim=1)
+            key_layer = torch.cat((past_key, key_layer), dim=2)
             value_layer = torch.cat((past_value, value_layer), dim=1)
 
         _, kv_length, _ = key_layer.shape
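
The value concatenation is left untouched by the patch, and under the same documented layout that is expected: for value, kv_length already sits on dim 1. A quick shape check with the same invented sizes as above:

```python
import torch

bh, head_dim, past_len, q_len = 4, 64, 7, 1

# Bloom-style value cache layout: [bh, kv_length, head_dim]
past_value = torch.randn(bh, past_len, head_dim)
value_layer = torch.randn(bh, q_len, head_dim)

# dim=1 is the kv_length axis for value, so this line was already correct:
value = torch.cat((past_value, value_layer), dim=1)
print(value.shape)  # torch.Size([4, 8, 64]) -- kv_length grew as intended
```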