fix-device-issues
#38
by
jupyterjazz
- opened
- modeling_lora.py +2 -1
- modeling_xlm_roberta.py +1 -0
- rotary.py +11 -14
modeling_lora.py
CHANGED
@@ -11,6 +11,7 @@ from torch.nn import Parameter
|
|
11 |
from torch.nn import functional as F
|
12 |
from transformers import PretrainedConfig
|
13 |
|
|
|
14 |
from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
|
15 |
XLMRobertaPreTrainedModel)
|
16 |
|
@@ -328,7 +329,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
|
|
328 |
)
|
329 |
else: # initializing new adapters
|
330 |
roberta = XLMRobertaModel.from_pretrained(
|
331 |
-
pretrained_model_name_or_path, *model_args, **kwargs
|
332 |
)
|
333 |
return cls(config, roberta=roberta)
|
334 |
|
|
|
11 |
from torch.nn import functional as F
|
12 |
from transformers import PretrainedConfig
|
13 |
|
14 |
+
from .rotary import RotaryEmbedding
|
15 |
from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
|
16 |
XLMRobertaPreTrainedModel)
|
17 |
|
|
|
329 |
)
|
330 |
else: # initializing new adapters
|
331 |
roberta = XLMRobertaModel.from_pretrained(
|
332 |
+
pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
|
333 |
)
|
334 |
return cls(config, roberta=roberta)
|
335 |
|
modeling_xlm_roberta.py
CHANGED
@@ -30,6 +30,7 @@ from transformers.models.bert.modeling_bert import (
|
|
30 |
from transformers.models.xlm_roberta.modeling_xlm_roberta import \
|
31 |
XLMRobertaLMHead
|
32 |
|
|
|
33 |
from .block import Block
|
34 |
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
35 |
from .embedding import XLMRobertaEmbeddings
|
|
|
30 |
from transformers.models.xlm_roberta.modeling_xlm_roberta import \
|
31 |
XLMRobertaLMHead
|
32 |
|
33 |
+
from .rotary import RotaryEmbedding
|
34 |
from .block import Block
|
35 |
from .configuration_xlm_roberta import XLMRobertaFlashConfig
|
36 |
from .embedding import XLMRobertaEmbeddings
|
rotary.py
CHANGED
@@ -9,6 +9,17 @@ from typing import Optional, Tuple, Union
|
|
9 |
import torch
|
10 |
from einops import rearrange, repeat
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
def rotate_half(x, interleaved=False):
|
14 |
if not interleaved:
|
@@ -60,8 +71,6 @@ class ApplyRotaryEmb(torch.autograd.Function):
|
|
60 |
cu_seqlens: Optional[torch.Tensor] = None,
|
61 |
max_seqlen: Optional[int] = None,
|
62 |
):
|
63 |
-
from flash_attn.ops.triton.rotary import apply_rotary
|
64 |
-
|
65 |
out = apply_rotary(
|
66 |
x,
|
67 |
cos,
|
@@ -88,8 +97,6 @@ class ApplyRotaryEmb(torch.autograd.Function):
|
|
88 |
|
89 |
@staticmethod
|
90 |
def backward(ctx, do):
|
91 |
-
from flash_attn.ops.triton.rotary import apply_rotary
|
92 |
-
|
93 |
seqlen_offsets = ctx.seqlen_offsets
|
94 |
if seqlen_offsets is None:
|
95 |
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
@@ -171,8 +178,6 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
|
171 |
if cos_k is None and sin_k is None and qkv.is_contiguous():
|
172 |
|
173 |
if use_flash_attn:
|
174 |
-
from flash_attn.ops.triton.rotary import apply_rotary
|
175 |
-
|
176 |
# Call 1 kernel instead of 2 kernels
|
177 |
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
|
178 |
# dimensions, we get the same tensor
|
@@ -203,8 +208,6 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
|
203 |
)
|
204 |
qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
|
205 |
else:
|
206 |
-
from flash_attn.ops.triton.rotary import apply_rotary
|
207 |
-
|
208 |
cos_k = cos if cos_k is None else cos_k
|
209 |
sin_k = sin if sin_k is None else sin_k
|
210 |
q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
|
@@ -241,8 +244,6 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
|
|
241 |
|
242 |
@staticmethod
|
243 |
def backward(ctx, dqkv):
|
244 |
-
from flash_attn.ops.triton.rotary import apply_rotary
|
245 |
-
|
246 |
seqlen_offsets = ctx.seqlen_offsets
|
247 |
if seqlen_offsets is None:
|
248 |
cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
@@ -340,8 +341,6 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
|
|
340 |
cu_seqlens: Optional[torch.Tensor] = None,
|
341 |
max_seqlen: Optional[int] = None,
|
342 |
):
|
343 |
-
from flash_attn.ops.triton.rotary import apply_rotary
|
344 |
-
|
345 |
# batch, seqlen, two, nheads, headdim = kv.shape
|
346 |
assert kv.shape[-3] == 2
|
347 |
k = kv[..., 0, :, :]
|
@@ -369,8 +368,6 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
|
|
369 |
|
370 |
@staticmethod
|
371 |
def backward(ctx, dkv):
|
372 |
-
from flash_attn.ops.triton.rotary import apply_rotary
|
373 |
-
|
374 |
seqlen_offsets = ctx.seqlen_offsets
|
375 |
if seqlen_offsets is None:
|
376 |
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
|
|
9 |
import torch
|
10 |
from einops import rearrange, repeat
|
11 |
|
12 |
+
if torch.cuda.is_available():
|
13 |
+
try:
|
14 |
+
from flash_attn.ops.triton.rotary import apply_rotary
|
15 |
+
except ImportError:
|
16 |
+
|
17 |
+
def apply_rotary(*args, **kwargs):
|
18 |
+
raise RuntimeError(
|
19 |
+
"FlashAttention is not installed. To proceed with training, please install FlashAttention. "
|
20 |
+
"For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
|
21 |
+
)
|
22 |
+
|
23 |
|
24 |
def rotate_half(x, interleaved=False):
|
25 |
if not interleaved:
|
|
|
71 |
cu_seqlens: Optional[torch.Tensor] = None,
|
72 |
max_seqlen: Optional[int] = None,
|
73 |
):
|
|
|
|
|
74 |
out = apply_rotary(
|
75 |
x,
|
76 |
cos,
|
|
|
97 |
|
98 |
@staticmethod
|
99 |
def backward(ctx, do):
|
|
|
|
|
100 |
seqlen_offsets = ctx.seqlen_offsets
|
101 |
if seqlen_offsets is None:
|
102 |
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
|
|
178 |
if cos_k is None and sin_k is None and qkv.is_contiguous():
|
179 |
|
180 |
if use_flash_attn:
|
|
|
|
|
181 |
# Call 1 kernel instead of 2 kernels
|
182 |
# We need qkv to be contiguous so that when we reshape to combine (3, nheads)
|
183 |
# dimensions, we get the same tensor
|
|
|
208 |
)
|
209 |
qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
|
210 |
else:
|
|
|
|
|
211 |
cos_k = cos if cos_k is None else cos_k
|
212 |
sin_k = sin if sin_k is None else sin_k
|
213 |
q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
|
|
|
244 |
|
245 |
@staticmethod
|
246 |
def backward(ctx, dqkv):
|
|
|
|
|
247 |
seqlen_offsets = ctx.seqlen_offsets
|
248 |
if seqlen_offsets is None:
|
249 |
cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|
|
|
341 |
cu_seqlens: Optional[torch.Tensor] = None,
|
342 |
max_seqlen: Optional[int] = None,
|
343 |
):
|
|
|
|
|
344 |
# batch, seqlen, two, nheads, headdim = kv.shape
|
345 |
assert kv.shape[-3] == 2
|
346 |
k = kv[..., 0, :, :]
|
|
|
368 |
|
369 |
@staticmethod
|
370 |
def backward(ctx, dkv):
|
|
|
|
|
371 |
seqlen_offsets = ctx.seqlen_offsets
|
372 |
if seqlen_offsets is None:
|
373 |
cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
|