Aman K commited on
Commit
932c11c
1 Parent(s): 2e5979b
events.out.tfevents.1625389806.t1v-n-9df4ce0e-w-0.369792.3.v2 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f3475f2b242b2380cbb03e47e425ac8cd5e01f9bd9650050d1210fec071653ac
3
- size 5624370
 
 
 
 
flax_model.msgpack DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:005d7634c02e588157e7650ac587bf2436cd58e55526f0df87d23c7e4def2d35
3
- size 498796983
 
 
 
 
run_mlm_flax.py CHANGED
@@ -611,7 +611,7 @@ if __name__ == "__main__":
611
  model_inputs = shard(model_inputs.data)
612
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
613
  train_metrics.append(train_metric)
614
- if save_checkpoint and (train_metric['loss'] < 5.).all():
615
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
616
  model.save_pretrained(
617
  '/home/khandelia1000/checkpoints/',
 
611
  model_inputs = shard(model_inputs.data)
612
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
613
  train_metrics.append(train_metric)
614
+ if save_checkpoint and (train_metric['loss'] < 1.).all():
615
  params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
616
  model.save_pretrained(
617
  '/home/khandelia1000/checkpoints/',