Vivek commited on
Commit
0972377
1 Parent(s): 13a2a4f

slight update

Browse files
Files changed (3) hide show
  1. .DS_Store +0 -0
  2. src/.DS_Store +0 -0
  3. src/gptneo_story.py +4 -2
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/gptneo_story.py CHANGED
@@ -123,7 +123,8 @@ def train_step(state,batch,dropout_rng):
123
  #Added.
124
  logits=new_state.apply_fn(**batch,params=new_state.params,dropout_rng=dropout_rng,train=True)[0]
125
  accuracy=jnp.equal(jnp.argmax(logits,axis=-1),targets)
126
- metrics=jax.lax.pmean({"loss":loss,"learning_rate":learning_rate_function(state.step),'accuracy':accuracy},axis_name="batch")
 
127
  return new_state,metrics,new_dropout_rng
128
 
129
  parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
@@ -135,7 +136,8 @@ def eval_step(state, batch):
135
  predictions=state.logits_function(logits)
136
  eval_accuracy=jnp.equal(predictions,targets)
137
  #eval_acc=jnp.equal(predictions,targets)
138
- metrics=jax.lax.pmean({"loss":loss,'accuracy':eval_accuracy},axis_name="batch")
 
139
  #return state.logits_function(logits) #(8,4)
140
  return targets,predictions,metrics
141
 
 
123
  #Added.
124
  logits=new_state.apply_fn(**batch,params=new_state.params,dropout_rng=dropout_rng,train=True)[0]
125
  accuracy=jnp.equal(jnp.argmax(logits,axis=-1),targets)
126
+ #metrics=jax.lax.pmean({"loss":loss,"learning_rate":learning_rate_function(state.step),'accuracy':accuracy},axis_name="batch")
127
+ metrics=jax.lax.pmean({"loss":jax.device_get(loss),"learning_rate":jax.device_get(learning_rate_function(state.step)),'accuracy':jax.device_get(accuracy)},axis_name="batch")
128
  return new_state,metrics,new_dropout_rng
129
 
130
  parallel_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,))
 
136
  predictions=state.logits_function(logits)
137
  eval_accuracy=jnp.equal(predictions,targets)
138
  #eval_acc=jnp.equal(predictions,targets)
139
+ metrics=jax.lax.pmean({"loss":jax.device_get(loss),'accuracy':jax.device_get(eval_accuracy)},axis_name="batch")
140
+ #metrics=jax.lax.pmean({"loss":loss,'accuracy':eval_accuracy},axis_name="batch")
141
  #return state.logits_function(logits) #(8,4)
142
  return targets,predictions,metrics
143