JustinLin610 commited on
Commit
cbf815e
1 Parent(s): f6498e5

support cpu inference, format file (#9)

Browse files

- support cpu inference, fix conflicts between fp32 and flash-attn (9f94ac27401a0fce0048ac05e4ab04880af57928)

Files changed (1) hide show
  1. modeling_qwen.py +64 -25
modeling_qwen.py CHANGED
@@ -15,6 +15,7 @@ from torch.cuda.amp import autocast
15
  from torch.nn import CrossEntropyLoss
16
  from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
17
  from transformers.generation.logits_process import LogitsProcessorList
 
18
  if TYPE_CHECKING:
19
  from transformers.generation.streamers import BaseStreamer
20
  from transformers.generation.utils import GenerateOutput
@@ -38,15 +39,19 @@ try:
38
  use_flash_rotary = True
39
  except ImportError:
40
  use_flash_rotary = False
41
- print("Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
42
- "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary")
 
 
43
 
44
  try:
45
  from flash_attn.ops.rms_norm import rms_norm
46
  except ImportError:
47
  rms_norm = None
48
- print("Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
49
- "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm")
 
 
50
 
51
  from .configuration_qwen import QWenConfig
52
  from .qwen_generation_utils import (
@@ -69,8 +74,10 @@ try:
69
  from flash_attn.flash_attn_interface import flash_attn_unpadded_func
70
  except ImportError:
71
  flash_attn_unpadded_func = None
72
- print("Warning: import flash_attn fail, please install FlashAttention "
73
- "https://github.com/Dao-AILab/flash-attention")
 
 
74
 
75
 
76
  class FlashSelfAttention(torch.nn.Module):
@@ -177,8 +184,12 @@ class QWenAttention(nn.Module):
177
  config.hidden_size, self.projection_size, bias=not config.no_bias
178
  )
179
 
180
- self.is_fp32 = not(config.bf16 or config.fp16)
181
- if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
 
 
 
 
182
  self.core_attention_flash = FlashSelfAttention(
183
  causal=True, attention_dropout=config.attn_pdrop
184
  )
@@ -197,14 +208,15 @@ class QWenAttention(nn.Module):
197
  if self.rotary_ndims is not None
198
  else self.hidden_size_per_attention_head
199
  )
200
- self.rotary_emb = RotaryEmbedding(
201
- dim, base=config.rotary_emb_base
202
- )
203
 
204
  self.use_dynamic_ntk = config.use_dynamic_ntk
205
  self.use_logn_attn = config.use_logn_attn
206
 
207
- logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, 32768)]
 
 
 
208
  self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
209
  self._ntk_cached = 1.0
210
 
@@ -335,14 +347,20 @@ class QWenAttention(nn.Module):
335
  if layer_past:
336
  # layer past[0] shape: bs * seq_len * head_num * dim
337
  kv_seq_len += layer_past[0].shape[1]
338
- if self.use_dynamic_ntk and kv_seq_len == hidden_states.size()[1] and not self.training:
 
 
 
 
339
  context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
340
  ntk_alpha = 2 ** math.ceil(context_value) - 1
341
  ntk_alpha = max(ntk_alpha, 1)
342
  self._ntk_cached = ntk_alpha
343
  else:
344
  ntk_alpha = self._ntk_cached
345
- rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(hidden_states.device)
 
 
346
 
347
  if rotary_pos_emb is not None:
348
  if isinstance(rotary_pos_emb, tuple):
@@ -377,7 +395,12 @@ class QWenAttention(nn.Module):
377
  logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
378
  query = query * logn_tensor.expand_as(query)
379
 
380
- if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
 
 
 
 
 
381
  q, k, v = query, key, value
382
  context_layer = self.core_attention_flash(q, k, v)
383
 
@@ -398,7 +421,11 @@ class QWenAttention(nn.Module):
398
  attn_output = self.c_proj(context_layer)
399
  outputs = (attn_output, present)
400
  if output_attentions:
401
- if self.use_flash_attn and flash_attn_unpadded_func is not None and not self.is_fp32:
 
 
 
 
402
  raise ValueError("Cannot output attentions while using flash-attn")
403
  else:
404
  outputs += (attn_weight,)
@@ -750,7 +777,9 @@ class QWenLMHeadModel(QWenPreTrainedModel):
750
  super().__init__(config)
751
  self.transformer = QWenModel(config)
752
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
753
- assert not(config.bf16 and config.fp16), ("In config, bf16 and fp16 cannot both be true")
 
 
754
  if config.bf16:
755
  self.transformer.bfloat16()
756
  self.lm_head.bfloat16()
@@ -929,21 +958,25 @@ class QWenLMHeadModel(QWenPreTrainedModel):
929
  generation_config: Optional[GenerationConfig] = None,
930
  logits_processor: Optional[LogitsProcessorList] = None,
931
  stopping_criteria: Optional[StoppingCriteriaList] = None,
932
- prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
 
 
933
  synced_gpus: Optional[bool] = None,
934
  streamer: Optional["BaseStreamer"] = None,
935
  **kwargs,
936
  ) -> Union[GenerateOutput, torch.LongTensor]:
937
  # Process stop_words_ids.
938
- stop_words_ids = kwargs.pop('stop_words_ids', None)
939
  if stop_words_ids is None and generation_config is not None:
940
- stop_words_ids = getattr(generation_config, 'stop_words_ids', None)
941
  if stop_words_ids is None:
942
- stop_words_ids = getattr(self.generation_config, 'stop_words_ids', None)
943
 
944
  if stop_words_ids is not None:
945
  stop_words_logits_processor = StopWordsLogitsProcessor(
946
- stop_words_ids=stop_words_ids, eos_token_id=self.generation_config.eos_token_id)
 
 
947
  if logits_processor is None:
948
  logits_processor = LogitsProcessorList([stop_words_logits_processor])
949
  else:
@@ -978,7 +1011,13 @@ class RotaryEmbedding(torch.nn.Module):
978
  seqlen = max_seq_len + offset
979
  if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
980
  base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
981
- self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=self.inv_freq.device).float() / self.dim))
 
 
 
 
 
 
982
  self._seq_len_cached = seqlen
983
  self._ntk_alpha_cached = ntk_alpha
984
  seq = torch.arange(seqlen, device=self.inv_freq.device)
@@ -1028,8 +1067,8 @@ class RMSNorm(torch.nn.Module):
1028
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
1029
 
1030
  def forward(self, x):
1031
- if rms_norm is not None:
1032
  return rms_norm(x, self.weight, self.eps)
1033
  else:
1034
  output = self._norm(x.float()).type_as(x)
1035
- return output * self.weight
 
15
  from torch.nn import CrossEntropyLoss
16
  from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
17
  from transformers.generation.logits_process import LogitsProcessorList
18
+
19
  if TYPE_CHECKING:
20
  from transformers.generation.streamers import BaseStreamer
21
  from transformers.generation.utils import GenerateOutput
 
39
  use_flash_rotary = True
40
  except ImportError:
41
  use_flash_rotary = False
42
+ print(
43
+ "Warning: import flash_attn rotary fail, please install FlashAttention rotary to get better performance "
44
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
45
+ )
46
 
47
  try:
48
  from flash_attn.ops.rms_norm import rms_norm
49
  except ImportError:
50
  rms_norm = None
51
+ print(
52
+ "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get better performance "
53
+ "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
54
+ )
55
 
56
  from .configuration_qwen import QWenConfig
57
  from .qwen_generation_utils import (
 
74
  from flash_attn.flash_attn_interface import flash_attn_unpadded_func
75
  except ImportError:
76
  flash_attn_unpadded_func = None
77
+ print(
78
+ "Warning: import flash_attn fail, please install FlashAttention "
79
+ "https://github.com/Dao-AILab/flash-attention"
80
+ )
81
 
82
 
83
  class FlashSelfAttention(torch.nn.Module):
 
184
  config.hidden_size, self.projection_size, bias=not config.no_bias
185
  )
186
 
187
+ self.is_fp32 = not (config.bf16 or config.fp16)
188
+ if (
189
+ self.use_flash_attn
190
+ and flash_attn_unpadded_func is not None
191
+ and not self.is_fp32
192
+ ):
193
  self.core_attention_flash = FlashSelfAttention(
194
  causal=True, attention_dropout=config.attn_pdrop
195
  )
 
208
  if self.rotary_ndims is not None
209
  else self.hidden_size_per_attention_head
210
  )
211
+ self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
 
 
212
 
213
  self.use_dynamic_ntk = config.use_dynamic_ntk
214
  self.use_logn_attn = config.use_logn_attn
215
 
216
+ logn_list = [
217
+ math.log(i, self.seq_length) if i > self.seq_length else 1
218
+ for i in range(1, 32768)
219
+ ]
220
  self.logn_tensor = torch.Tensor(logn_list)[None, :, None, None]
221
  self._ntk_cached = 1.0
222
 
 
347
  if layer_past:
348
  # layer past[0] shape: bs * seq_len * head_num * dim
349
  kv_seq_len += layer_past[0].shape[1]
350
+ if (
351
+ self.use_dynamic_ntk
352
+ and kv_seq_len == hidden_states.size()[1]
353
+ and not self.training
354
+ ):
355
  context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
356
  ntk_alpha = 2 ** math.ceil(context_value) - 1
357
  ntk_alpha = max(ntk_alpha, 1)
358
  self._ntk_cached = ntk_alpha
359
  else:
360
  ntk_alpha = self._ntk_cached
361
+ rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha).to(
362
+ hidden_states.device
363
+ )
364
 
365
  if rotary_pos_emb is not None:
366
  if isinstance(rotary_pos_emb, tuple):
 
395
  logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
396
  query = query * logn_tensor.expand_as(query)
397
 
398
+ if (
399
+ self.use_flash_attn
400
+ and flash_attn_unpadded_func is not None
401
+ and not self.is_fp32
402
+ and query.is_cuda
403
+ ):
404
  q, k, v = query, key, value
405
  context_layer = self.core_attention_flash(q, k, v)
406
 
 
421
  attn_output = self.c_proj(context_layer)
422
  outputs = (attn_output, present)
423
  if output_attentions:
424
+ if (
425
+ self.use_flash_attn
426
+ and flash_attn_unpadded_func is not None
427
+ and not self.is_fp32
428
+ ):
429
  raise ValueError("Cannot output attentions while using flash-attn")
430
  else:
431
  outputs += (attn_weight,)
 
777
  super().__init__(config)
778
  self.transformer = QWenModel(config)
779
  self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
780
+ assert not (
781
+ config.bf16 and config.fp16
782
+ ), "In config, bf16 and fp16 cannot both be true"
783
  if config.bf16:
784
  self.transformer.bfloat16()
785
  self.lm_head.bfloat16()
 
958
  generation_config: Optional[GenerationConfig] = None,
959
  logits_processor: Optional[LogitsProcessorList] = None,
960
  stopping_criteria: Optional[StoppingCriteriaList] = None,
961
+ prefix_allowed_tokens_fn: Optional[
962
+ Callable[[int, torch.Tensor], List[int]]
963
+ ] = None,
964
  synced_gpus: Optional[bool] = None,
965
  streamer: Optional["BaseStreamer"] = None,
966
  **kwargs,
967
  ) -> Union[GenerateOutput, torch.LongTensor]:
968
  # Process stop_words_ids.
969
+ stop_words_ids = kwargs.pop("stop_words_ids", None)
970
  if stop_words_ids is None and generation_config is not None:
971
+ stop_words_ids = getattr(generation_config, "stop_words_ids", None)
972
  if stop_words_ids is None:
973
+ stop_words_ids = getattr(self.generation_config, "stop_words_ids", None)
974
 
975
  if stop_words_ids is not None:
976
  stop_words_logits_processor = StopWordsLogitsProcessor(
977
+ stop_words_ids=stop_words_ids,
978
+ eos_token_id=self.generation_config.eos_token_id,
979
+ )
980
  if logits_processor is None:
981
  logits_processor = LogitsProcessorList([stop_words_logits_processor])
982
  else:
 
1011
  seqlen = max_seq_len + offset
1012
  if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
1013
  base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
1014
+ self.inv_freq = 1.0 / (
1015
+ base
1016
+ ** (
1017
+ torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
1018
+ / self.dim
1019
+ )
1020
+ )
1021
  self._seq_len_cached = seqlen
1022
  self._ntk_alpha_cached = ntk_alpha
1023
  seq = torch.arange(seqlen, device=self.inv_freq.device)
 
1067
  return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
1068
 
1069
  def forward(self, x):
1070
+ if rms_norm is not None and x.is_cuda:
1071
  return rms_norm(x, self.weight, self.eps)
1072
  else:
1073
  output = self._norm(x.float()).type_as(x)
1074
+ return output * self.weight