Files changed (3)
  1. modeling_lora.py +2 -1
  2. modeling_xlm_roberta.py +1 -0
  3. rotary.py +11 -14
modeling_lora.py CHANGED
@@ -11,6 +11,7 @@ from torch.nn import Parameter
 from torch.nn import functional as F
 from transformers import PretrainedConfig
 
+from .rotary import RotaryEmbedding
 from .modeling_xlm_roberta import (XLMRobertaFlashConfig, XLMRobertaModel,
                                    XLMRobertaPreTrainedModel)
 
@@ -328,7 +329,7 @@ class XLMRobertaLoRA(XLMRobertaPreTrainedModel):
             )
         else:  # initializing new adapters
             roberta = XLMRobertaModel.from_pretrained(
-                pretrained_model_name_or_path, *model_args, **kwargs
+                pretrained_model_name_or_path, *model_args, use_flash_attn=config.use_flash_attn, **kwargs
             )
             return cls(config, roberta=roberta)
 
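Note: with the flag forwarded here, FlashAttention can be switched off at load time. A minimal usage sketch, assuming the model is loaded through AutoModel with trust_remote_code; the repo id below is a placeholder, not part of this diff:

# Hypothetical usage: run inference without FlashAttention installed.
# The use_flash_attn kwarg is forwarded into the config by from_pretrained,
# matching the change above; the repo id is a placeholder.
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "org/xlm-roberta-flash-model",  # placeholder repo id
    trust_remote_code=True,
    use_flash_attn=False,
)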
modeling_xlm_roberta.py CHANGED
@@ -30,6 +30,7 @@ from transformers.models.bert.modeling_bert import (
 from transformers.models.xlm_roberta.modeling_xlm_roberta import \
     XLMRobertaLMHead
 
+from .rotary import RotaryEmbedding
 from .block import Block
 from .configuration_xlm_roberta import XLMRobertaFlashConfig
 from .embedding import XLMRobertaEmbeddings
rotary.py CHANGED
@@ -9,6 +9,17 @@ from typing import Optional, Tuple, Union
 import torch
 from einops import rearrange, repeat
 
+if torch.cuda.is_available():
+    try:
+        from flash_attn.ops.triton.rotary import apply_rotary
+    except ImportError:
+
+        def apply_rotary(*args, **kwargs):
+            raise RuntimeError(
+                "FlashAttention is not installed. To proceed with training, please install FlashAttention. "
+                "For inference, you have two options: either install FlashAttention or disable it by setting use_flash_attn=False when loading the model."
+            )
+
 
 def rotate_half(x, interleaved=False):
     if not interleaved:
@@ -60,8 +71,6 @@ class ApplyRotaryEmb(torch.autograd.Function):
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[int] = None,
     ):
-        from flash_attn.ops.triton.rotary import apply_rotary
-
         out = apply_rotary(
             x,
             cos,
@@ -88,8 +97,6 @@ class ApplyRotaryEmb(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, do):
-        from flash_attn.ops.triton.rotary import apply_rotary
-
         seqlen_offsets = ctx.seqlen_offsets
         if seqlen_offsets is None:
             cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
@@ -171,8 +178,6 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         if cos_k is None and sin_k is None and qkv.is_contiguous():
 
             if use_flash_attn:
-                from flash_attn.ops.triton.rotary import apply_rotary
-
                 # Call 1 kernel instead of 2 kernels
                 # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
                 # dimensions, we get the same tensor
@@ -203,8 +208,6 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
                 )
                 qkv = torch.stack((q_rot, k_rot, qkv[:, :, 2]), dim=2)
         else:
-            from flash_attn.ops.triton.rotary import apply_rotary
-
             cos_k = cos if cos_k is None else cos_k
             sin_k = sin if sin_k is None else sin_k
             q, k = qkv[..., 0, :, :], qkv[..., 1, :, :]
@@ -241,8 +244,6 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, dqkv):
-        from flash_attn.ops.triton.rotary import apply_rotary
-
         seqlen_offsets = ctx.seqlen_offsets
         if seqlen_offsets is None:
             cos, sin, cos_k, sin_k, cu_seqlens, seqlen_offsets = ctx.saved_tensors
@@ -340,8 +341,6 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[int] = None,
     ):
-        from flash_attn.ops.triton.rotary import apply_rotary
-
         # batch, seqlen, two, nheads, headdim = kv.shape
         assert kv.shape[-3] == 2
         k = kv[..., 0, :, :]
@@ -369,8 +368,6 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
 
     @staticmethod
    def backward(ctx, dkv):
-        from flash_attn.ops.triton.rotary import apply_rotary
-
         seqlen_offsets = ctx.seqlen_offsets
         if seqlen_offsets is None:
             cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
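For context, when FlashAttention is missing and the model is loaded with use_flash_attn=False, rotary embeddings are applied with plain tensor ops instead of the Triton kernel. Below is a minimal sketch of that computation for the non-interleaved case, consistent with rotate_half() above. It illustrates the formula apply_rotary implements, assuming cos/sin are shaped (seqlen, rotary_dim // 2) as in flash-attn's convention; it is not the Triton kernel's exact signature.

import torch

def apply_rotary_torch(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim // 2).
    # Rotate the first rotary_dim features and pass the rest through,
    # i.e. x * cos + rotate_half(x) * sin for the non-interleaved layout.
    ro_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :ro_dim], x[..., ro_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    cos = cos[None, :, None, :]  # broadcast over batch and heads
    sin = sin[None, :, None, :]
    rotated = torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
    return torch.cat((rotated, x_pass), dim=-1)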