3e-7 lr update
Browse files
- .DS_Store +0 -0
- src/gptneo_story.py +1 -1
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
src/gptneo_story.py
CHANGED
@@ -75,7 +75,7 @@ print('The overall batch size (both for training and eval) is', total_batch_size
|
|
75 |
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
|
76 |
num_validation_steps=len(validation_dataset)//total_batch_size*num_train_epochs
|
77 |
|
78 |
-
learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=
|
79 |
|
80 |
class TrainState(train_state.TrainState):
|
81 |
logits_function:Callable=flax.struct.field(pytree_node=False)
|
|
|
75 |
num_train_steps = len(train_dataset) // total_batch_size * num_train_epochs
|
76 |
num_validation_steps=len(validation_dataset)//total_batch_size*num_train_epochs
|
77 |
|
78 |
+
learning_rate_function = optax.linear_schedule(init_value=learning_rate, end_value=3e-7, transition_steps=num_train_steps)
|
79 |
|
80 |
class TrainState(train_state.TrainState):
|
81 |
logits_function:Callable=flax.struct.field(pytree_node=False)
|