Commit b6e4fc1
jon-tow committed
1 Parent(s): c24bc36

feat: add dropout support

configuration_stablelm_epoch.py CHANGED
@@ -64,6 +64,8 @@ class StableLMEpochConfig(PretrainedConfig):
             (not used by all models). Only relevant if `config.is_decoder=True`.
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
     """
     model_type = "stablelm_epoch"
     keys_to_ignore_at_inference = ["past_key_values"]
@@ -86,6 +88,7 @@ class StableLMEpochConfig(PretrainedConfig):
         bos_token_id=0,
         eos_token_id=2,
         tie_word_embeddings=False,
+        attention_dropout: float = 0.0,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -102,6 +105,7 @@ class StableLMEpochConfig(PretrainedConfig):
         self.norm_eps = norm_eps
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
+        self.attention_dropout = attention_dropout
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
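
For context, a minimal usage sketch of the new config field. The checkpoint id below is a placeholder (not taken from this commit), and `attention_dropout` defaults to 0.0, so existing behavior is unchanged unless the value is explicitly overridden, e.g. for fine-tuning:

from transformers import AutoConfig, AutoModelForCausalLM

# Placeholder repo id; substitute the actual StableLM Epoch checkpoint.
repo_id = "stabilityai/stablelm-epoch-checkpoint"

# Extra keyword arguments to `from_pretrained` are forwarded onto matching config
# attributes; `trust_remote_code=True` is required because `stablelm_epoch` ships
# its configuration and modeling code with the repository.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True, attention_dropout=0.1)
model = AutoModelForCausalLM.from_pretrained(repo_id, config=config, trust_remote_code=True)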
modeling_stablelm_epoch.py CHANGED
@@ -191,6 +191,7 @@ class Attention(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.is_causal = True
+        self.attention_dropout = config.attention_dropout

         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -274,6 +275,7 @@ class Attention(nn.Module):

         # Upcast attention to fp32
         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
         attn_output = torch.matmul(attn_weights, value_states)

         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
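
As a self-contained sketch of the pattern the modeling change follows (dropout applied to the attention probabilities after the fp32 softmax, active only during training); the function name and tensor shapes here are illustrative, not taken from the repository:

import torch
from torch import nn

def attention_with_dropout(
    query: torch.Tensor,             # (bsz, num_heads, q_len, head_dim)
    key: torch.Tensor,               # (bsz, num_heads, kv_len, head_dim)
    value: torch.Tensor,             # (bsz, num_heads, kv_len, head_dim)
    attention_dropout: float = 0.0,  # mirrors `config.attention_dropout`
    training: bool = False,
) -> torch.Tensor:
    head_dim = query.size(-1)
    # Scaled dot-product attention scores.
    attn_weights = torch.matmul(query, key.transpose(-2, -1)) / head_dim**0.5
    # Upcast to fp32 for a numerically stable softmax, then cast back to the input dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    # The step added by this commit: randomly zero attention probabilities during training.
    attn_weights = nn.functional.dropout(attn_weights, p=attention_dropout, training=training)
    return torch.matmul(attn_weights, value)

With `training=False` the dropout call is a no-op, matching the `training=self.training` gating in the module, so inference outputs are unaffected.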