slight update
Browse files
- .DS_Store +0 -0
- src/.DS_Store +0 -0
- src/gptneo_story.py +4 -2
.DS_Store
CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
|
|
src/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
src/gptneo_story.py
CHANGED
@@ -123,7 +123,8 @@ def train_step(state,batch,dropout_rng):
|
|
123 |
#Added.
|
124 |
logits=new_state.apply_fn(**batch,params=new_state.params,dropout_rng=dropout_rng,train=True)[0]
|
125 |
accuracy=jnp.equal(jnp.argmax(logits,axis=-1),targets)
|
126 |
-
metrics=jax.lax.pmean({"loss":loss,"learning_rate":learning_rate_function(state.step),'accuracy':accuracy},axis_name="batch")
|
|
|
127 |
return new_state,metrics,new_dropout_rng
|
128 |
|
129 |
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
|
@@ -135,7 +136,8 @@ def eval_step(state, batch):
|
|
135 |
predictions=state.logits_function(logits)
|
136 |
eval_accuracy=jnp.equal(predictions,targets)
|
137 |
#eval_acc=jnp.equal(predictions,targets)
|
138 |
-
metrics=jax.lax.pmean({"loss":loss,'accuracy':eval_accuracy},axis_name="batch")
|
|
|
139 |
#return state.logits_function(logits) #(8,4)
|
140 |
return targets,predictions,metrics
|
141 |
|
|
|
123 |
#Added.
|
124 |
logits=new_state.apply_fn(**batch,params=new_state.params,dropout_rng=dropout_rng,train=True)[0]
|
125 |
accuracy=jnp.equal(jnp.argmax(logits,axis=-1),targets)
|
126 |
+
#metrics=jax.lax.pmean({"loss":loss,"learning_rate":learning_rate_function(state.step),'accuracy':accuracy},axis_name="batch")
|
127 |
+
metrics=jax.lax.pmean({"loss":jax.device_get(loss),"learning_rate":jax.device_get(learning_rate_function(state.step)),'accuracy':jax.device_get(accuracy)},axis_name="batch")
|
128 |
return new_state,metrics,new_dropout_rng
|
129 |
|
130 |
parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
|
|
|
136 |
predictions=state.logits_function(logits)
|
137 |
eval_accuracy=jnp.equal(predictions,targets)
|
138 |
#eval_acc=jnp.equal(predictions,targets)
|
139 |
+
metrics=jax.lax.pmean({"loss":jax.device_get(loss),'accuracy':jax.device_get(eval_accuracy)},axis_name="batch")
|
140 |
+
#metrics=jax.lax.pmean({"loss":loss,'accuracy':eval_accuracy},axis_name="batch")
|
141 |
#return state.logits_function(logits) #(8,4)
|
142 |
return targets,predictions,metrics
|
143 |
|