OrionZheng
commited on
Commit
•
eca9f18
1
Parent(s):
baf5e8d
Update modeling_openmoe.py
Browse files- modeling_openmoe.py +96 -85
modeling_openmoe.py
CHANGED
@@ -48,40 +48,6 @@ logger = logging.get_logger(__name__)
|
|
48 |
|
49 |
_CONFIG_FOR_DOC = "LlamaConfig"
|
50 |
|
51 |
-
class LlamaRotaryEmbedding(nn.Module):
|
52 |
-
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
53 |
-
super().__init__()
|
54 |
-
|
55 |
-
self.dim = dim
|
56 |
-
self.max_position_embeddings = max_position_embeddings
|
57 |
-
self.base = base
|
58 |
-
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
|
59 |
-
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
60 |
-
|
61 |
-
# Build here to make `torch.jit.trace` work.
|
62 |
-
self._set_cos_sin_cache(
|
63 |
-
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
|
64 |
-
)
|
65 |
-
|
66 |
-
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
67 |
-
self.max_seq_len_cached = seq_len
|
68 |
-
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
|
69 |
-
|
70 |
-
freqs = torch.outer(t, self.inv_freq) # (seq_len, dim//2)
|
71 |
-
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
72 |
-
emb = torch.cat((freqs, freqs), dim=-1) # (seq_len, dim)
|
73 |
-
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
74 |
-
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
75 |
-
|
76 |
-
def forward(self, x, seq_len=None):
|
77 |
-
# x: [bs, num_attention_heads, seq_len, head_size]
|
78 |
-
if seq_len > self.max_seq_len_cached:
|
79 |
-
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
80 |
-
|
81 |
-
return (
|
82 |
-
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
83 |
-
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
84 |
-
)
|
85 |
|
86 |
def set_openmoe_args(
|
87 |
config: LlamaConfig,
|
@@ -191,6 +157,72 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|
191 |
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
192 |
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
def rotate_half(x):
|
195 |
"""Rotates half the hidden dims of the input."""
|
196 |
x1 = x[..., : x.shape[-1] // 2]
|
@@ -198,33 +230,6 @@ def rotate_half(x):
|
|
198 |
return torch.cat((-x2, x1), dim=-1)
|
199 |
|
200 |
|
201 |
-
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
202 |
-
"""Applies Rotary Position Embedding to the query and key tensors.
|
203 |
-
|
204 |
-
Args:
|
205 |
-
q (`torch.Tensor`): The query tensor.
|
206 |
-
k (`torch.Tensor`): The key tensor.
|
207 |
-
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
208 |
-
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
209 |
-
position_ids (`torch.Tensor`):
|
210 |
-
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
211 |
-
used to pass offsetted position ids when working with a KV-cache.
|
212 |
-
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
213 |
-
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
214 |
-
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
215 |
-
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
216 |
-
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
217 |
-
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
218 |
-
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
219 |
-
Returns:
|
220 |
-
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
221 |
-
"""
|
222 |
-
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
223 |
-
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
224 |
-
q_embed = (q * cos) + (rotate_half(q) * sin)
|
225 |
-
k_embed = (k * cos) + (rotate_half(k) * sin)
|
226 |
-
return q_embed, k_embed
|
227 |
-
|
228 |
def SwiGLU(x):
|
229 |
"""Gated linear unit activation function.
|
230 |
Args:
|
@@ -297,24 +302,15 @@ class OpenMoeAttention(nn.Module):
|
|
297 |
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
298 |
self.pretraining_tp = config.pretraining_tp
|
299 |
self.max_position_embeddings = config.max_position_embeddings
|
300 |
-
self.rope_theta = config.rope_theta
|
301 |
|
302 |
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
303 |
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
304 |
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
305 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
306 |
-
self.
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
self.rotary_emb = LlamaRotaryEmbedding(
|
311 |
-
self.head_dim,
|
312 |
-
max_position_embeddings=self.max_position_embeddings,
|
313 |
-
base=self.rope_theta,
|
314 |
-
)
|
315 |
-
else:
|
316 |
-
raise ValueError(f"Only Original RotaryEmbedding is supported yet")
|
317 |
-
|
318 |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
319 |
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
320 |
|
@@ -350,29 +346,44 @@ class OpenMoeAttention(nn.Module):
|
|
350 |
key_states = self.k_proj(hidden_states)
|
351 |
value_states = self.v_proj(hidden_states)
|
352 |
|
353 |
-
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
354 |
-
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
355 |
-
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
356 |
|
357 |
-
kv_seq_len = key_states.shape[-2]
|
|
|
|
|
|
|
|
|
358 |
if past_key_value is not None:
|
359 |
-
kv_seq_len += past_key_value[0].shape[-2]
|
360 |
# reuse k, v, self_attention
|
361 |
-
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
362 |
-
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
363 |
|
364 |
past_key_value = (key_states, value_states) if use_cache else None
|
365 |
|
366 |
-
|
367 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
368 |
|
369 |
# repeat k/v heads if n_kv_heads < n_heads
|
370 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
371 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
372 |
|
373 |
if HAS_FLASH_ATTN and use_kernel:
|
374 |
-
|
375 |
-
|
|
|
|
|
|
|
376 |
query_states = query_states.transpose(1, 2)
|
377 |
key_states = key_states.transpose(1, 2)
|
378 |
value_states = value_states.transpose(1, 2)
|
|
|
48 |
|
49 |
_CONFIG_FOR_DOC = "LlamaConfig"
|
50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
def set_openmoe_args(
|
53 |
config: LlamaConfig,
|
|
|
157 |
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
158 |
|
159 |
|
160 |
+
def generate_fixed_pos_embedding(features, length, min_timescale=1.0, max_timescale=10000.0):
|
161 |
+
"""Generate Sin/Cos for Rotary Embeddings.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
features: an integer
|
165 |
+
length: an integer
|
166 |
+
min_timescale: an optional float
|
167 |
+
max_timescale: an optional float
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
output_sin: a float32 Tensor with shape [length, features]
|
171 |
+
output_cos: a float32 Tensor with shape [length, features]
|
172 |
+
"""
|
173 |
+
fraction = torch.arange(0, features, 2, dtype=torch.float32) / features
|
174 |
+
timescale = min_timescale * (max_timescale / min_timescale) ** fraction
|
175 |
+
rotational_frequency = 1.0 / timescale
|
176 |
+
|
177 |
+
sinusoid_inp = torch.einsum("i,j->ij", torch.arange(length, dtype=torch.float32), rotational_frequency)
|
178 |
+
|
179 |
+
sinusoid_inp = torch.cat([sinusoid_inp, sinusoid_inp], dim=-1)
|
180 |
+
|
181 |
+
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
|
182 |
+
|
183 |
+
|
184 |
+
def apply_rotary_embedding(q, k, cos, sin, decode=False, rotary_index=None):
|
185 |
+
# q: (bs, q_len, num_heads, head_dim)
|
186 |
+
# k: (bs, q_len [+past_kv_len], num_heads, head_dim)
|
187 |
+
# cos: (max_seq_len, head_dim)
|
188 |
+
# sin: (max_seq_len, head_dim)
|
189 |
+
# rotary_index: (bs, 1) # only used during decoding, when one query token is input at a time
|
190 |
+
"""Helper function to apply Rotary Embeddings."""
|
191 |
+
cos = cos.to(q.dtype)
|
192 |
+
sin = sin.to(q.dtype)
|
193 |
+
|
194 |
+
if len(k.shape) == 3: # for multi query attention
|
195 |
+
k = k.unsqueeze(2)
|
196 |
+
multiquery = True
|
197 |
+
else:
|
198 |
+
multiquery = False
|
199 |
+
|
200 |
+
batch, qlen, qheads, d = q.shape
|
201 |
+
kbatch, klen, kheads, kd = k.shape
|
202 |
+
assert batch == kbatch, f"{batch} != {kbatch}"
|
203 |
+
assert d == kd, f"{d} != {kd}"
|
204 |
+
if decode and qlen == 1 and rotary_index is not None:
|
205 |
+
qcos = cos[rotary_index, :] # (bs, 1, head_dim)
|
206 |
+
qsin = sin[rotary_index, :] # (bs, 1, head_dim)
|
207 |
+
qcos = qcos.unsqueeze(2) # (bs, q_len=1, 1, head_dim) # broadcast to all heads
|
208 |
+
qsin = qsin.unsqueeze(2) # (bs, q_len=1, 1, head_dim)
|
209 |
+
else:
|
210 |
+
qcos, qsin = cos[:qlen, :], sin[:qlen, :] # (q_len, head_dim)
|
211 |
+
qcos = qcos.unsqueeze(0).unsqueeze(2) # (1, q_len, 1, head_dim)
|
212 |
+
qsin = qsin.unsqueeze(0).unsqueeze(2)
|
213 |
+
|
214 |
+
kcos, ksin = cos[:klen, :], sin[:klen, :] # (k_len, head_dim)
|
215 |
+
kcos = kcos.unsqueeze(0).unsqueeze(2) # (1, k_len, 1, head_dim) # broadcast to the whole batch, broadcast to all heads
|
216 |
+
ksin = ksin.unsqueeze(0).unsqueeze(2) # (1, k_len, 1, head_dim)
|
217 |
+
out_q = (q * qcos) + (rotate_half(q) * qsin)
|
218 |
+
out_k = (k * kcos) + (rotate_half(k) * ksin)
|
219 |
+
|
220 |
+
if multiquery:
|
221 |
+
out_k = out_k.squeeze(2)
|
222 |
+
|
223 |
+
return out_q, out_k
|
224 |
+
|
225 |
+
|
226 |
def rotate_half(x):
|
227 |
"""Rotates half the hidden dims of the input."""
|
228 |
x1 = x[..., : x.shape[-1] // 2]
|
|
|
230 |
return torch.cat((-x2, x1), dim=-1)
|
231 |
|
232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
def SwiGLU(x):
|
234 |
"""Gated linear unit activation function.
|
235 |
Args:
|
|
|
302 |
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
303 |
self.pretraining_tp = config.pretraining_tp
|
304 |
self.max_position_embeddings = config.max_position_embeddings
|
|
|
305 |
|
306 |
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
|
307 |
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
308 |
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
|
309 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
310 |
+
sin, cos = generate_fixed_pos_embedding(self.head_dim, self.max_position_embeddings, 1.0, 1e4)
|
311 |
+
self.register_buffer('sin', sin)
|
312 |
+
self.register_buffer('cos', cos)
|
313 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
315 |
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
316 |
|
|
|
346 |
key_states = self.k_proj(hidden_states)
|
347 |
value_states = self.v_proj(hidden_states)
|
348 |
|
349 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
350 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
351 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
352 |
|
353 |
+
kv_seq_len = key_states.shape[-2]
|
354 |
+
if past_key_value is not None:
|
355 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
356 |
+
# cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
357 |
+
# query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
358 |
if past_key_value is not None:
|
|
|
359 |
# reuse k, v, self_attention
|
360 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
361 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
362 |
|
363 |
past_key_value = (key_states, value_states) if use_cache else None
|
364 |
|
365 |
+
query_states = query_states.transpose(1, 2)
|
366 |
+
key_states = key_states.transpose(1, 2)
|
367 |
+
max_length = max(query_states.shape[1], key_states.shape[1])
|
368 |
+
assert max_length <= self.sin.shape[0]
|
369 |
+
sin, cos = self.sin[:max_length], self.cos[:max_length]
|
370 |
+
# TODO: for inference, we can add emb kv into cache to avoid computation
|
371 |
+
query_states, key_states = apply_rotary_embedding(
|
372 |
+
query_states, key_states, cos, sin, decode=True if q_len == 1 else False, rotary_index=position_ids
|
373 |
+
)
|
374 |
+
query_states = query_states.transpose(1, 2)
|
375 |
+
key_states = key_states.transpose(1, 2)
|
376 |
|
377 |
# repeat k/v heads if n_kv_heads < n_heads
|
378 |
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
379 |
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
380 |
|
381 |
if HAS_FLASH_ATTN and use_kernel:
|
382 |
+
# If we use `from flash_attn import flash_attn_func` directly,
|
383 |
+
# AutoModelForCausalLM.from_pretrained will treat flash_attn as a compulsory dependency and raise error if cannot find.
|
384 |
+
# Here is a workaround to avoid the error.
|
385 |
+
exec("from flash_attn import flash_attn_func")
|
386 |
+
|
387 |
query_states = query_states.transpose(1, 2)
|
388 |
key_states = key_states.transpose(1, 2)
|
389 |
value_states = value_states.transpose(1, 2)
|