yangapku committed on
Commit
26fad65
1 Parent(s): f2e5005

update config about model precision, fix apply_rotary_pos_emb

Files changed (3):
  1. config.json +3 -2
  2. configuration_qwen.py +5 -1
  3. modeling_qwen.py +80 -35
config.json CHANGED
@@ -10,12 +10,13 @@
   },
   "attn_pdrop": 0.0,
   "bf16": false,
+  "fp16": false,
+  "fp32": false,
   "bias_dropout_fusion": true,
   "bos_token_id": 151643,
   "embd_pdrop": 0.1,
   "eos_token_id": 151643,
   "ffn_hidden_size": 22016,
-  "fp16": false,
   "initializer_range": 0.02,
   "kv_channels": 128,
   "layer_norm_epsilon": 1e-05,
@@ -38,7 +39,7 @@
   "tokenizer_type": "QWenTokenizer",
   "transformers_version": "4.31.0",
   "use_cache": true,
-  "use_flash_attn": true,
+  "use_flash_attn": "auto",
   "vocab_size": 151936,
   "use_dynamic_ntk": true,
   "use_logn_attn": true
configuration_qwen.py CHANGED
@@ -31,7 +31,9 @@ class QWenConfig(PretrainedConfig):
         use_cache=True,
         eos_token_id=151643,
         apply_residual_connection_post_layernorm=False,
-        bf16=True,
+        bf16=False,
+        fp16=False,
+        fp32=False,
         kv_channels=128,
         rotary_pct=1.0,
         rotary_emb_base=10000,
@@ -63,6 +65,8 @@
             apply_residual_connection_post_layernorm
         )
         self.bf16 = bf16
+        self.fp16 = fp16
+        self.fp32 = fp32
         self.kv_channels = kv_channels
         self.rotary_pct = rotary_pct
         self.rotary_emb_base = rotary_emb_base
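For completeness, the new keyword arguments can also be set when building a config by hand. A minimal sketch, assuming configuration_qwen.py is importable from the working directory:

# Hedged sketch: constructing the config with the precision kwargs added in this commit.
from configuration_qwen import QWenConfig  # assumed local import path

config = QWenConfig(bf16=False, fp16=True, fp32=False)  # modeling_qwen.py asserts at most one is True
print(config.fp16)  # True
print(config.fp32)  # False; leaving all three False lets the model auto-select a precision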
modeling_qwen.py CHANGED
@@ -32,26 +32,13 @@ except ImportError:
     rearrange = None
 from torch import nn
 
-try:
-    from flash_attn.layers.rotary import apply_rotary_emb_func
-    from einops import rearrange
-
-    use_flash_rotary = True
-except ImportError:
-    use_flash_rotary = False
-    print(
-        "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
-        "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
-    )
+SUPPORT_CUDA = torch.cuda.is_available()
+SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
+SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
 
-try:
-    from flash_attn.ops.rms_norm import rms_norm
-except ImportError:
-    rms_norm = None
-    print(
-        "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
-        "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
-    )
+apply_rotary_emb_func = None
+rms_norm = None
+flash_attn_unpadded_func = None
 
 from .configuration_qwen import QWenConfig
 from .qwen_generation_utils import (
@@ -70,16 +57,6 @@ _CONFIG_FOR_DOC = "QWenConfig"
 
 QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
 
-try:
-    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
-except ImportError:
-    flash_attn_unpadded_func = None
-    print(
-        "Warning: import flash_attn fail, please install FlashAttention "
-        "https://github.com/Dao-AILab/flash-attention"
-    )
-
-
 class FlashSelfAttention(torch.nn.Module):
     def __init__(
         self,
@@ -388,7 +365,7 @@ class QWenAttention(nn.Module):
         present = None
 
         if self.use_logn_attn and not self.training:
-            if self.logn_tensor.device != query.device:
+            if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
                 self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
             seq_start = key.size(1) - query.size(1)
             seq_end = key.size(1)
@@ -775,11 +752,79 @@ class QWenLMHeadModel(QWenPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
+        assert (
+            config.bf16 + config.fp16 + config.fp32 <= 1
+        ), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
+
+        autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
+
+        if autoset_precision:
+            if SUPPORT_BF16:
+                logger.warn(
+                    "The model is automatically converting to bf16 for faster inference. "
+                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                )
+                config.bf16 = True
+            elif SUPPORT_FP16:
+                logger.warn(
+                    "The model is automatically converting to fp16 for faster inference. "
+                    "If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
+                )
+                config.fp16 = True
+            else:
+                config.fp32 = True
+
+        if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
+            logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
+        if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
+            logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
+        if config.fp32:
+            if SUPPORT_BF16:
+                logger.warn("Your device supports faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
+            elif SUPPORT_FP16:
+                logger.warn("Your device supports faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
+
+        if config.use_flash_attn == "auto":
+            if config.bf16 or config.fp16:
+                logger.warn("Try importing flash-attention for faster inference...")
+                config.use_flash_attn = True
+            else:
+                config.use_flash_attn = False
+        if config.use_flash_attn and config.fp32:
+            logger.warn("Flash attention will be disabled because it does NOT support fp32.")
+
+        if config.use_flash_attn:
+            global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
+            try:
+                from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
+                apply_rotary_emb_func = __apply_rotary_emb_func
+            except ImportError:
+                logger.warn(
+                    "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
+                    "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
+                )
+
+            try:
+                from flash_attn.ops.rms_norm import rms_norm as __rms_norm
+                rms_norm = __rms_norm
+            except ImportError:
+                logger.warn(
+                    "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
+                    "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
+                )
+
+            try:
+                from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
+                flash_attn_unpadded_func = __flash_attn_unpadded_func
+            except ImportError:
+                logger.warn(
+                    "Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
+                    "https://github.com/Dao-AILab/flash-attention"
+                )
+
         self.transformer = QWenModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        assert not (
-            config.bf16 and config.fp16
-        ), "In config, bf16 and fp16 cannot both be true"
+
         if config.bf16:
             self.transformer.bfloat16()
             self.lm_head.bfloat16()
@@ -1040,8 +1085,8 @@ def _rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-def apply_rotary_pos_emb(t, freqs, use_flash_rotary=False):
-    if use_flash_rotary:
+def apply_rotary_pos_emb(t, freqs):
+    if apply_rotary_emb_func is not None:
         t_ = t.float()
         freqs = freqs.squeeze(0).squeeze(1)
         cos = freqs[:, : freqs.shape[-1] // 2].cos()
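The module-level SUPPORT_* probes above drive the new precision auto-selection, and apply_rotary_pos_emb now keys off whether the flash-attention rotary kernel was actually imported rather than a boolean flag. A standalone sketch of the same capability checks, runnable outside the model to see which precision the auto-selection would pick on a given machine (variable names here are illustrative, not from the commit):

import torch

# Same probes as the new module-level code in modeling_qwen.py.
support_cuda = torch.cuda.is_available()
support_bf16 = support_cuda and torch.cuda.is_bf16_supported()
support_fp16 = support_cuda and torch.cuda.get_device_capability(0)[0] >= 7  # Volta (sm_70) or newer

if support_bf16:
    picked = "bf16"
elif support_fp16:
    picked = "fp16"
else:
    picked = "fp32"
print(f"auto-selected precision: {picked}")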