aapot committed
Commit • 0b67ff4
Parent(s): 916632f

Update optimizers

Browse files:
- EasyLM/data.py +7 -0
- EasyLM/optimizers.py +47 -3
- pretrain_llama_3b.sh +2 -1
EasyLM/data.py
CHANGED

@@ -153,6 +153,7 @@ class HuggingfaceDataset(object):
         config.start_seek_loc = 0
         config.tokens_count_at_start = 0
         config.batch_token_dtype = 'i4'
+        config.reset_dataset_loc = False

         if updates is not None:
             config.update(ConfigDict(updates).copy_and_resolve_references())
@@ -173,6 +174,8 @@ class HuggingfaceDataset(object):
         self._dataset_loc = self.config.start_seek_loc
         self._total_tokens = self.config.tokens_count_at_start
         self._index = 0
+        self.reset_dataset_loc = self.config.reset_dataset_loc
+

     def __iter__(self):
         if not self._eval_dataset and self._train_epochs > 0:
@@ -236,6 +239,10 @@ class HuggingfaceDataset(object):
         self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
         self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
         self._train_epochs = state_dict.get('epochs', 0)
+        if self.reset_dataset_loc:
+            self._dataset_loc = 0
+            self._train_epochs = 0
+

     @property
     def seq_length(self):
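The data.py change adds a reset_dataset_loc switch so that a run restored from a training checkpoint can optionally restart the data stream from the beginning instead of resuming at the saved cursor. Below is a minimal sketch of those semantics, using a hypothetical stand-in class, not the actual EasyLM HuggingfaceDataset:

# Hypothetical stand-in illustrating the resume semantics of the new flag.
class DatasetStateSketch:
    def __init__(self, reset_dataset_loc=False, start_seek_loc=0):
        self.reset_dataset_loc = reset_dataset_loc
        self.start_seek_loc = start_seek_loc
        self._dataset_loc = start_seek_loc
        self._train_epochs = 0

    def load_state_dict(self, state_dict):
        # Restore the dataset cursor saved in the checkpoint...
        self._dataset_loc = state_dict.get('dataset_loc', self.start_seek_loc)
        self._train_epochs = state_dict.get('epochs', 0)
        # ...then discard it if the flag asks for a fresh pass over the data.
        if self.reset_dataset_loc:
            self._dataset_loc = 0
            self._train_epochs = 0

ds = DatasetStateSketch(reset_dataset_loc=True)
ds.load_state_dict({'dataset_loc': 123456, 'epochs': 2})
assert ds._dataset_loc == 0 and ds._train_epochs == 0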
EasyLM/optimizers.py
CHANGED

@@ -205,8 +205,9 @@ class LionOptimizerFactory(object):
         config.init_lr = 0.0
         config.end_lr = 0.0001
         config.lr = 0.001
-        config.lr_warmup_steps =
-        config.
+        config.lr_warmup_steps = 60000
+        config.lr_constant_steps = 840000
+        config.lr_decay_steps = 100000
         config.b1 = 0.9
         config.b2 = 0.98
         config.clip_gradient = 1.0
@@ -243,6 +244,43 @@ class LionOptimizerFactory(object):
                 ],
                 [config.lr_warmup_steps],
             )
+        elif config.lr_schedule_type == "warmup_constant_linear_decay":
+            learning_rate_schedule = optax.join_schedules(
+                [
+                    optax.linear_schedule(
+                        init_value=config.init_lr,
+                        end_value=config.lr,
+                        transition_steps=config.lr_warmup_steps,
+                    ),
+                    optax.constant_schedule(config.lr),
+                    optax.linear_schedule(
+                        init_value=config.lr,
+                        end_value=config.end_lr,
+                        transition_steps=config.lr_decay_steps,
+                    )
+                ],
+                [config.lr_warmup_steps, config.lr_constant_steps],
+            )
+        elif config.lr_schedule_type == "warmup_constant_exponential_decay":
+            learning_rate_schedule = optax.join_schedules(
+                [
+                    optax.linear_schedule(
+                        init_value=config.init_lr,
+                        end_value=config.lr,
+                        transition_steps=config.lr_warmup_steps,
+                    ),
+                    optax.constant_schedule(config.lr),
+                    optax.exponential_decay(
+                        init_value=config.lr,
+                        transition_steps=config.lr_decay_steps,
+                        decay_rate=config.lr_decay_rate,
+                        transition_begin=0,
+                        staircase=False,
+                        end_value=config.end_lr,
+                    )
+                ],
+                [config.lr_warmup_steps, config.lr_constant_steps],
+            )
         elif config.lr_schedule_type == "exponential_decay":
             learning_rate_schedule = optax.exponential_decay(
                 init_value=config.lr,
@@ -252,8 +290,14 @@ class LionOptimizerFactory(object):
                 staircase=False,
                 end_value=config.end_lr,
             )
+        elif config.lr_schedule_type == "linear_decay":
+            learning_rate_schedule = optax.linear_schedule(
+                init_value=config.lr,
+                end_value=config.end_lr,
+                transition_steps=config.lr_decay_steps,
+            )
         else:
-            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant",
+            raise ValueError('config.lr_schedule_type must be "warmup_cosine_decay_schedule", "warmup_constant", "warmup_constant_linear_decay", "warmup_constant_exponential_decay", "exponential_decay" or "linear_decay"')

         optimizer_info = dict(
             learning_rate_schedule=learning_rate_schedule,
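The new schedule types compose optax primitives with optax.join_schedules. Below is a standalone sketch of "warmup_constant_linear_decay" using the same optax calls as the diff; the step values are taken from pretrain_llama_3b.sh further down, and the variable names are illustrative only:

import optax

init_lr, lr, end_lr = 0.0, 1e-4, 1e-5
lr_warmup_steps = 60_000
lr_constant_steps = 900_000  # absolute step at which decay begins
lr_decay_steps = 100_000

schedule = optax.join_schedules(
    [
        # Linear warmup from init_lr to lr over the first 60k steps.
        optax.linear_schedule(
            init_value=init_lr, end_value=lr,
            transition_steps=lr_warmup_steps,
        ),
        # Hold at lr during the constant phase.
        optax.constant_schedule(lr),
        # Linear decay from lr to end_lr over the final 100k steps.
        optax.linear_schedule(
            init_value=lr, end_value=end_lr,
            transition_steps=lr_decay_steps,
        ),
    ],
    # join_schedules boundaries are absolute step counts: the constant
    # phase runs from step 60k to step 900k, i.e. for 840k steps.
    [lr_warmup_steps, lr_constant_steps],
)

for step in (0, 30_000, 60_000, 500_000, 900_000, 1_000_000):
    print(step, float(schedule(step)))
# 0 -> 0.0, 60k -> 1e-4 (warmup done), 900k -> 1e-4 (decay starts), 1M -> 1e-5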
pretrain_llama_3b.sh
CHANGED

@@ -23,10 +23,11 @@ python3 -m EasyLM.models.llama.llama_train \
     --tokenizer.vocab_file='tokenizer.model' \
     --optimizer.type='lion' \
     --optimizer.lion_optimizer.weight_decay=1.0 \
-    --optimizer.lion_optimizer.lr_schedule_type='
+    --optimizer.lion_optimizer.lr_schedule_type='warmup_constant_linear_decay' \
     --optimizer.lion_optimizer.lr=1e-4 \
     --optimizer.lion_optimizer.end_lr=1e-5 \
     --optimizer.lion_optimizer.lr_warmup_steps=60000 \
+    --optimizer.lion_optimizer.lr_constant_steps=900000 \
     --optimizer.lion_optimizer.lr_decay_steps=100000 \
     --optimizer.lion_optimizer.bf16_momentum=True \
     --train_dataset.type='huggingface' \
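Note that because optax.join_schedules treats its boundaries as absolute step counts, lr_constant_steps=900000 here marks the step at which decay begins: 60k warmup steps, then an 840k-step constant phase, then a 100k-step linear decay to end_lr, for roughly one million training steps in total.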