Vivek committed on
Commit
2f0f3f3
1 Parent(s): b368398

Saving weights of epoch 1 at step 92

Browse files
__pycache__/model_file.cpython-38.pyc CHANGED
Binary files a/__pycache__/model_file.cpython-38.pyc and b/__pycache__/model_file.cpython-38.pyc differ
 
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:68265cb6ae648f60d66c66a3cbf9f29b7deeb2054a911c27a482ec8afb33b9c0
3
  size 1419367919
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f16abc176050411b46334568e4cda7b396760a303a5e489c0f5c009be496ae6d
3
  size 1419367919
model_file.py CHANGED
@@ -190,7 +190,7 @@ class FlaxGPT2ForMultipleChoiceModule(nn.Module):
190
  dtype: jnp.dtype = jnp.float32
191
  def setup(self):
192
  self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
193
- self.dropout = nn.Dropout(rate=0.2)
194
  self.classifier = nn.Dense(4, dtype=self.dtype)
195
 
196
  def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
@@ -223,4 +223,4 @@ class FlaxGPT2ForMultipleChoiceModule(nn.Module):
223
  return reshaped_logits
224
 
225
  class FlaxGPT2ForMultipleChoice(FlaxGPT2PreTrainedModel):
226
- module_class = FlaxGPT2ForMultipleChoiceModule
 
190
  dtype: jnp.dtype = jnp.float32
191
  def setup(self):
192
  self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
193
+ self.dropout = nn.Dropout(rate=0.3)
194
  self.classifier = nn.Dense(4, dtype=self.dtype)
195
 
196
  def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
 
223
  return reshaped_logits
224
 
225
  class FlaxGPT2ForMultipleChoice(FlaxGPT2PreTrainedModel):
226
+ module_class = FlaxGPT2ForMultipleChoiceModule
results_tensorboard/events.out.tfevents.1626338427.t1v-n-8cb15980-w-0.773256.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d6aa87397b55c231b84b2c1c25964400b46c0de5ebe44310c9571cdff77549e0
3
+ size 25038
train.py CHANGED
@@ -73,8 +73,8 @@ def main():
73
 
74
  per_device_batch_size=4
75
  seed=0
76
- num_train_epochs=5
77
- learning_rate=2e-5
78
 
79
 
80
  total_batch_size = per_device_batch_size * jax.local_device_count()
 
73
 
74
  per_device_batch_size=4
75
  seed=0
76
+ num_train_epochs=3
77
+ learning_rate=4e-5
78
 
79
 
80
  total_batch_size = per_device_batch_size * jax.local_device_count()