Aman K
commited on
Commit
•
932c11c
1
Parent(s):
2e5979b
Fix on lr
Browse files
events.out.tfevents.1625389806.t1v-n-9df4ce0e-w-0.369792.3.v2
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:f3475f2b242b2380cbb03e47e425ac8cd5e01f9bd9650050d1210fec071653ac
|
3 |
-
size 5624370
|
|
|
|
|
|
|
|
flax_model.msgpack
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:005d7634c02e588157e7650ac587bf2436cd58e55526f0df87d23c7e4def2d35
|
3 |
-
size 498796983
|
|
|
|
|
|
|
|
run_mlm_flax.py
CHANGED
@@ -611,7 +611,7 @@ if __name__ == "__main__":
|
|
611 |
model_inputs = shard(model_inputs.data)
|
612 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
613 |
train_metrics.append(train_metric)
|
614 |
-
if save_checkpoint and (train_metric['loss'] <
|
615 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
616 |
model.save_pretrained(
|
617 |
'/home/khandelia1000/checkpoints/',
|
|
|
611 |
model_inputs = shard(model_inputs.data)
|
612 |
state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
|
613 |
train_metrics.append(train_metric)
|
614 |
+
if save_checkpoint and (train_metric['loss'] < 1.).all():
|
615 |
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
616 |
model.save_pretrained(
|
617 |
'/home/khandelia1000/checkpoints/',
|