yangapku committed on
Commit
a6ca629
1 Parent(s): 6837ccb

softmax_in_fp32

Files changed (2)
  1. configuration_qwen.py +4 -2
  2. modeling_qwen.py +5 -1
configuration_qwen.py CHANGED

@@ -37,6 +37,7 @@ class QWenConfig(PretrainedConfig):
         tie_word_embeddings=False,
         use_cache_quantization=False,
         use_cache_kernel=False,
+        softmax_in_fp32=False,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -61,8 +62,9 @@ class QWenConfig(PretrainedConfig):
         self.use_logn_attn = use_logn_attn
         self.use_flash_attn = use_flash_attn
         self.no_bias = no_bias
-        self.use_cache_quantization=use_cache_quantization
-        self.use_cache_kernel=use_cache_kernel
+        self.use_cache_quantization = use_cache_quantization
+        self.use_cache_kernel = use_cache_kernel
+        self.softmax_in_fp32 = softmax_in_fp32
         super().__init__(
             tie_word_embeddings=tie_word_embeddings,
             **kwargs
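
With this change the flag is part of the model configuration, so it can be toggled at load time through the usual transformers config-override path. A minimal sketch, assuming a checkpoint that ships this configuration_qwen.py (the repo id below is only illustrative):

    from transformers import AutoConfig, AutoModelForCausalLM

    # Illustrative repo id; any Qwen checkpoint carrying this configuration_qwen.py behaves the same.
    config = AutoConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
    config.softmax_in_fp32 = True  # run the attention softmax in float32

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-7B-Chat",
        config=config,
        trust_remote_code=True,
    )

Older config.json files that omit the key still load; the hasattr guard added in modeling_qwen.py below falls back to False.
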
modeling_qwen.py CHANGED

@@ -280,6 +280,7 @@ class QWenAttention(nn.Module):
         self.register_buffer("logn_tensor", logn_tensor, persistent=False)

         self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
+        self.softmax_in_fp32 = config.softmax_in_fp32 if hasattr(config, 'softmax_in_fp32') else False
         self.use_cache_quantization = config.use_cache_quantization if hasattr(config, 'use_cache_quantization') else False
         self.use_cache_kernel = config.use_cache_kernel if hasattr(config,'use_cache_kernel') else False
         cache_dtype = torch.float
@@ -346,7 +347,10 @@ class QWenAttention(nn.Module):
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask

-        attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
+        if self.softmax_in_fp32:
+            attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
+        else:
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

         attn_weights = attn_weights.type(query.dtype)
         attn_weights = self.attn_dropout(attn_weights)
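
For reference, a minimal standalone sketch of the branch added above, with placeholder tensors (attn_weights, query_dtype, and the literal flag value are illustrative, not taken from the model code): only the softmax itself is upcast to float32 when the flag is set, and the result is cast back to the query dtype immediately afterwards, so downstream dtypes are unchanged either way.

    import torch
    import torch.nn as nn

    # Placeholder half-precision attention scores standing in for attn_weights above.
    attn_weights = torch.randn(1, 2, 8, 8, dtype=torch.bfloat16)
    query_dtype = torch.bfloat16     # stands in for query.dtype
    softmax_in_fp32 = True           # mirrors config.softmax_in_fp32

    if softmax_in_fp32:
        # Upcast only for the normalization, the numerically sensitive step.
        attn_weights = nn.functional.softmax(attn_weights.float(), dim=-1)
    else:
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    # Cast back to the query dtype, exactly as the context line after the branch does.
    attn_weights = attn_weights.type(query_dtype)
    assert attn_weights.dtype == query_dtype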