ssmi153 committed
Commit 10405b9 (unverified)
1 Parent(s): c93655c

Update XFormers Attention Monkeypatch to handle Llama-2 70B (GQA) (#339)

* Fix XFormers attention for Llama-2 70B (GQA)

Updated the XFormers monkeypatch to handle grouped-query attention (GQA) as used in Llama-2 70B. All of the updated code is taken directly from the Transformers library's modeling_llama.py file: https://github.com/huggingface/transformers/commit/07360b6c9c9448d619a82798419ed291dfc6ac8f#diff-06392bad3b9e97be9ade60d4ac46f73b6809388f4d507c2ba1384ab872711c51 (a short sketch of the GQA head-expansion step follows these notes).

* Catch configs without pretraining_tp (see the weight-slicing sketch below)

* Whitespace bug fix

The statement had accidentally been moved out of the if-else block.

* pre-commit formatting fixes

Thanks to @winglian
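For context on the GQA change: in Llama-2 70B the key/value projections produce only 8 heads while the query projection produces 64, so the patched forward expands the K/V heads with `transformers.models.llama.modeling_llama.repeat_kv` before the attention call. The snippet below is a minimal, self-contained sketch of that expansion; `repeat_kv_sketch` is a stand-in for the library helper and the shapes are illustrative only.

```python
import torch


def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each KV head n_rep times:
    (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)


# Illustrative Llama-2 70B-like shapes: 64 query heads, 8 KV heads -> n_rep = 8.
key_states = torch.randn(1, 8, 16, 128)
print(repeat_kv_sketch(key_states, n_rep=64 // 8).shape)  # torch.Size([1, 64, 16, 128])
```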
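On the pretraining_tp guard: when `pretraining_tp > 1`, the code copied from Transformers splits each projection weight into tensor-parallel slices, projects with each slice separately, and concatenates the pieces, which should be numerically equivalent to the plain projection. A small check of that equivalence, using made-up toy dimensions rather than anything from the real model config:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, num_heads, head_dim, tp = 64, 8, 8, 2  # toy sizes, not the real config
q_proj = torch.nn.Linear(hidden_size, num_heads * head_dim, bias=False)
hidden_states = torch.randn(2, 5, hidden_size)

# pretraining_tp == 1 path: one full matmul.
full = q_proj(hidden_states)

# pretraining_tp > 1 path: split the weight along the output dimension,
# project with each slice, then concatenate on the feature dimension.
query_slices = q_proj.weight.split((num_heads * head_dim) // tp, dim=0)
sliced = torch.cat([F.linear(hidden_states, s) for s in query_slices], dim=-1)

print(torch.allclose(full, sliced, atol=1e-6))  # True
```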

src/axolotl/monkeypatch/llama_attn_hijack_xformers.py CHANGED
@@ -7,6 +7,7 @@ import math
 from typing import Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 import transformers.models.llama.modeling_llama
 from torch import nn
 
@@ -38,21 +39,48 @@ def xformers_forward(
     # pylint: disable=duplicate-code
     bsz, q_len, _ = hidden_states.size()
 
-    query_states = (
-        self.q_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    key_states = (
-        self.k_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
-    value_states = (
-        self.v_proj(hidden_states)
-        .view(bsz, q_len, self.num_heads, self.head_dim)
-        .transpose(1, 2)
-    )
+    if not hasattr(self, "pretraining_tp"):
+        self.pretraining_tp = 1
+
+    if self.pretraining_tp > 1:
+        key_value_slicing = (
+            self.num_key_value_heads * self.head_dim
+        ) // self.pretraining_tp
+        query_slices = self.q_proj.weight.split(
+            (self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
+        )
+        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+        query_states = [
+            F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        query_states = torch.cat(query_states, dim=-1)
+
+        key_states = [
+            F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        key_states = torch.cat(key_states, dim=-1)
+
+        value_states = [
+            F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
+        ]
+        value_states = torch.cat(value_states, dim=-1)
+
+    else:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+    query_states = query_states.view(
+        bsz, q_len, self.num_heads, self.head_dim
+    ).transpose(1, 2)
+    key_states = key_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
+    value_states = value_states.view(
+        bsz, q_len, self.num_key_value_heads, self.head_dim
+    ).transpose(1, 2)
 
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
@@ -73,6 +101,14 @@ def xformers_forward(
 
     past_key_value = (key_states, value_states) if use_cache else None
 
+    # repeat k/v heads if n_kv_heads < n_heads
+    key_states = transformers.models.llama.modeling_llama.repeat_kv(
+        key_states, self.num_key_value_groups
+    )
+    value_states = transformers.models.llama.modeling_llama.repeat_kv(
+        value_states, self.num_key_value_groups
+    )
+
     # We only apply xformers optimizations if we don't need to output the whole attention matrix
     if not output_attentions:
         query_states = query_states.transpose(1, 2)
@@ -128,10 +164,23 @@ def xformers_forward(
                 f" {attn_output.size()}"
             )
 
-    attn_output = attn_output.transpose(1, 2)
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    # end x-formers vs. not x-formers if-else block
 
     attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-    attn_output = self.o_proj(attn_output)
+
+    if self.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(
+            self.hidden_size // self.pretraining_tp, dim=1
+        )
+        attn_output = sum(
+            F.linear(attn_output[i], o_proj_slices[i])
+            for i in range(self.pretraining_tp)
+        )
+    else:
+        attn_output = self.o_proj(attn_output)
+
     return attn_output, attn_weights, past_key_value
 
 
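For reference, a patch like this takes effect by overwriting the attention forward on the Transformers class before the model is built, so every Llama layer (including GQA models such as Llama-2 70B) routes through the patched code. Below is a minimal sketch of that wiring; the `apply_xformers_attention_patch` helper is hypothetical, and the exact entry point axolotl exposes may differ.

```python
import transformers
import transformers.models.llama.modeling_llama

# xformers_forward is the patched function shown in the diff above.
from axolotl.monkeypatch.llama_attn_hijack_xformers import xformers_forward


def apply_xformers_attention_patch():
    # Hypothetical helper: replace LlamaAttention.forward at the class level
    # so all attention layers pick up the patched implementation.
    transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward


apply_xformers_attention_patch()
# Load the model only after patching; loading Llama-2 70B requires
# appropriate hardware and model access.
model = transformers.LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-70b-hf")
```

Patching the class rather than individual module instances is what keeps the change transparent to the rest of the training code.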