Vivek commited on
Commit
aa11173
1 Parent(s): e868809

Saving weights of epoch 1 at step 92

Browse files
__pycache__/model_file.cpython-38.pyc CHANGED
Binary files a/__pycache__/model_file.cpython-38.pyc and b/__pycache__/model_file.cpython-38.pyc differ
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7566bc459faaabe05b6f642a97a367b444365761679f1b5dd13312d70b413601
3
  size 1419367919
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51d86bd352715e1623b69a8451f8c752c314bb6cf7669a5d9bb2f7589261d8c3
3
  size 1419367919
model_file.py CHANGED
@@ -190,7 +190,7 @@ class FlaxGPT2ForMultipleChoiceModule(nn.Module):
190
  dtype: jnp.dtype = jnp.float32
191
  def setup(self):
192
  self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
193
- self.dropout = nn.Dropout(rate=0.3)
194
  self.classifier = nn.Dense(4, dtype=self.dtype)
195
 
196
  def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
190
  dtype: jnp.dtype = jnp.float32
191
  def setup(self):
192
  self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
193
+ self.dropout = nn.Dropout(rate=0.2)
194
  self.classifier = nn.Dense(4, dtype=self.dtype)
195
 
196
  def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
results_tensorboard/events.out.tfevents.1626339960.t1v-n-8cb15980-w-0.776261.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:322d306ebcf9d805c02057d9c2c761e63a887231db3d665df7d6dc88bed92174
3
+ size 25038
train.py CHANGED
@@ -74,7 +74,7 @@ def main():
74
  per_device_batch_size=4
75
  seed=0
76
  num_train_epochs=3
77
- learning_rate=4e-5
78
 
79
 
80
  total_batch_size = per_device_batch_size * jax.local_device_count()
74
  per_device_batch_size=4
75
  seed=0
76
  num_train_epochs=3
77
+ learning_rate=2e-5
78
 
79
 
80
  total_batch_size = per_device_batch_size * jax.local_device_count()