Tags: Text Generation, Transformers, Safetensors, Finnish, llama, conversational, text-generation-inference
Commit 0b67ff4 by aapot
1 Parent(s): 916632f

Update optimizers

Files changed (3):
  1. EasyLM/data.py +7 -0
  2. EasyLM/optimizers.py +47 -3
  3. pretrain_llama_3b.sh +2 -1
EasyLM/data.py CHANGED
@@ -153,6 +153,7 @@ class HuggingfaceDataset(object):
         config.start_seek_loc = 0
         config.tokens_count_at_start = 0
         config.batch_token_dtype = 'i4'
+        config.reset_dataset_loc = False

         if updates is not None:
             config.update(ConfigDict(updates).copy_and_resolve_references())
@@ -173,6 +174,8 @@ class HuggingfaceDataset(object):
         self._dataset_loc = self.config.start_seek_loc
         self._total_tokens = self.config.tokens_count_at_start
         self._index = 0
+        self.reset_dataset_loc = self.config.reset_dataset_loc
+

     def __iter__(self):
         if not self._eval_dataset and self._train_epochs > 0:
@@ -236,6 +239,10 @@ class HuggingfaceDataset(object):
         self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
         self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
         self._train_epochs = state_dict.get('epochs', 0)
+        if self.reset_dataset_loc:
+            self._dataset_loc = 0
+            self._train_epochs = 0
+

     @property
     def seq_length(self):
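For context, the new reset_dataset_loc flag changes how a saved dataset state is restored: when it is enabled, the dataset read position and epoch counter from the checkpoint are discarded (reading restarts from the beginning of the data) while the restored token counter is kept. A minimal standalone sketch of that resume behavior, for illustration only and not the actual HuggingfaceDataset class:

# Standalone sketch of the resume behavior introduced by reset_dataset_loc
# (illustration only; the real logic lives in HuggingfaceDataset above).
class ResumableDataset:
    def __init__(self, reset_dataset_loc=False):
        self._dataset_loc = 0
        self._total_tokens = 0
        self._train_epochs = 0
        self.reset_dataset_loc = reset_dataset_loc

    def load_state_dict(self, state_dict):
        self._dataset_loc = state_dict.get('dataset_loc', 0)
        self._total_tokens = state_dict.get('total_tokens', 0)
        self._train_epochs = state_dict.get('epochs', 0)
        if self.reset_dataset_loc:
            # Restart from the beginning of the dataset, but keep the
            # token counter restored from the checkpoint.
            self._dataset_loc = 0
            self._train_epochs = 0

resumed = ResumableDataset(reset_dataset_loc=True)
resumed.load_state_dict({'dataset_loc': 123456, 'total_tokens': 10**9, 'epochs': 2})
assert resumed._dataset_loc == 0 and resumed._train_epochs == 0
assert resumed._total_tokens == 10**9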
EasyLM/optimizers.py CHANGED
@@ -205,8 +205,9 @@ class LionOptimizerFactory(object):
         config.init_lr = 0.0
         config.end_lr = 0.0001
         config.lr = 0.001
-        config.lr_warmup_steps = 2000
-        config.lr_decay_steps = 500000
+        config.lr_warmup_steps = 60000
+        config.lr_constant_steps = 840000
+        config.lr_decay_steps = 100000
         config.b1 = 0.9
         config.b2 = 0.98
         config.clip_gradient = 1.0
@@ -243,6 +244,43 @@ class LionOptimizerFactory(object):
                ],
                [config.lr_warmup_steps],
            )
+        elif config.lr_schedule_type == "warmup_constant_linear_decay":
+            learning_rate_schedule = optax.join_schedules(
+                [
+                    optax.linear_schedule(
+                        init_value=config.init_lr,
+                        end_value=config.lr,
+                        transition_steps=config.lr_warmup_steps,
+                    ),
+                    optax.constant_schedule(config.lr),
+                    optax.linear_schedule(
+                        init_value=config.lr,
+                        end_value=config.end_lr,
+                        transition_steps=config.lr_decay_steps,
+                    )
+                ],
+                [config.lr_warmup_steps, config.lr_constant_steps],
+            )
+        elif config.lr_schedule_type == "warmup_constant_exponential_decay":
+            learning_rate_schedule = optax.join_schedules(
+                [
+                    optax.linear_schedule(
+                        init_value=config.init_lr,
+                        end_value=config.lr,
+                        transition_steps=config.lr_warmup_steps,
+                    ),
+                    optax.constant_schedule(config.lr),
+                    optax.exponential_decay(
+                        init_value=config.lr,
+                        transition_steps=config.lr_decay_steps,
+                        decay_rate=config.lr_decay_rate,
+                        transition_begin=0,
+                        staircase=False,
+                        end_value=config.end_lr,
+                    )
+                ],
+                [config.lr_warmup_steps, config.lr_constant_steps],
+            )
         elif config.lr_schedule_type == "exponential_decay":
             learning_rate_schedule = optax.exponential_decay(
                 init_value=config.lr,
@@ -252,8 +290,14 @@ class LionOptimizerFactory(object):
                 staircase=False,
                 end_value=config.end_lr,
             )
+        elif config.lr_schedule_type == "linear_decay":
+            learning_rate_schedule = optax.linear_schedule(
+                init_value=config.lr,
+                end_value=config.end_lr,
+                transition_steps=config.lr_decay_steps,
+            )
         else:
-            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", or "exponential_decay"')
+            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')

         optimizer_info = dict(
             learning_rate_schedule=learning_rate_schedule,
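The new warmup_constant_linear_decay branch composes three optax schedules with optax.join_schedules, whose boundaries are absolute step counts: warm up linearly for lr_warmup_steps, hold the peak rate until step lr_constant_steps, then decay linearly over lr_decay_steps. A small sketch of the same composition with the new default values (it only evaluates the schedule and is not part of this commit):

import optax

# Sketch of the "warmup_constant_linear_decay" schedule with the new defaults:
# lr_warmup_steps=60000, lr_constant_steps=840000, lr_decay_steps=100000,
# init_lr=0.0, lr=0.001, end_lr=0.0001.
init_lr, lr, end_lr = 0.0, 0.001, 0.0001
warmup_steps, constant_steps, decay_steps = 60_000, 840_000, 100_000

schedule = optax.join_schedules(
    [
        optax.linear_schedule(init_value=init_lr, end_value=lr,
                              transition_steps=warmup_steps),
        optax.constant_schedule(lr),
        optax.linear_schedule(init_value=lr, end_value=end_lr,
                              transition_steps=decay_steps),
    ],
    # Boundaries are global step counts at which the next schedule takes over:
    # warmup until step 60k, constant until step 840k, then 100k decay steps.
    [warmup_steps, constant_steps],
)

for step in (0, 30_000, 60_000, 500_000, 840_000, 890_000, 940_000):
    print(step, float(schedule(step)))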
pretrain_llama_3b.sh CHANGED
@@ -23,10 +23,11 @@ python3 -m EasyLM.models.llama.llama_train \
    --tokenizer.vocab_file='tokenizer.model' \
    --optimizer.type='lion' \
    --optimizer.lion_optimizer.weight_decay=1.0 \
-    --optimizer.lion_optimizer.lr_schedule_type='warmup_constant' \
+    --optimizer.lion_optimizer.lr_schedule_type='warmup_constant_linear_decay' \
    --optimizer.lion_optimizer.lr=1e-4 \
    --optimizer.lion_optimizer.end_lr=1e-5 \
    --optimizer.lion_optimizer.lr_warmup_steps=60000 \
+    --optimizer.lion_optimizer.lr_constant_steps=900000 \
    --optimizer.lion_optimizer.lr_decay_steps=100000 \
    --optimizer.lion_optimizer.bf16_momentum=True \
    --train_dataset.type='huggingface' \
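With the values in this script, the schedule warms up over the first 60,000 steps to the peak learning rate of 1e-4, holds it there until step 900,000, and then decays linearly over the final 100,000 steps to 1e-5, so the decay ends around step 1,000,000. The lr_constant_steps=900000 flag overrides the 840000 default set in EasyLM/optimizers.py.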