Saving weights of epoch 1 at step 92
Browse files
__pycache__/model_file.cpython-38.pyc
CHANGED
Binary files a/__pycache__/model_file.cpython-38.pyc and b/__pycache__/model_file.cpython-38.pyc differ
|
|
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1419367919
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f16abc176050411b46334568e4cda7b396760a303a5e489c0f5c009be496ae6d
|
3 |
size 1419367919
|
model_file.py
CHANGED
@@ -190,7 +190,7 @@ class FlaxGPT2ForMultipleChoiceModule(nn.Module):
|
|
190 |
dtype: jnp.dtype = jnp.float32
|
191 |
def setup(self):
|
192 |
self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
|
193 |
-
self.dropout = nn.Dropout(rate=0.
|
194 |
self.classifier = nn.Dense(4, dtype=self.dtype)
|
195 |
|
196 |
def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
|
@@ -223,4 +223,4 @@ class FlaxGPT2ForMultipleChoiceModule(nn.Module):
|
|
223 |
return reshaped_logits
|
224 |
|
225 |
class FlaxGPT2ForMultipleChoice(FlaxGPT2PreTrainedModel):
|
226 |
-
module_class = FlaxGPT2ForMultipleChoiceModule
|
|
|
190 |
dtype: jnp.dtype = jnp.float32
|
191 |
def setup(self):
|
192 |
self.transformer = FlaxGPT2Module(config=self.config, dtype=self.dtype)
|
193 |
+
self.dropout = nn.Dropout(rate=0.3)
|
194 |
self.classifier = nn.Dense(4, dtype=self.dtype)
|
195 |
|
196 |
def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):
|
|
|
223 |
return reshaped_logits
|
224 |
|
225 |
class FlaxGPT2ForMultipleChoice(FlaxGPT2PreTrainedModel):
|
226 |
+
module_class = FlaxGPT2ForMultipleChoiceModule
|
results_tensorboard/events.out.tfevents.1626338427.t1v-n-8cb15980-w-0.773256.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d6aa87397b55c231b84b2c1c25964400b46c0de5ebe44310c9571cdff77549e0
|
3 |
+
size 25038
|
train.py
CHANGED
@@ -73,8 +73,8 @@ def main():
|
|
73 |
|
74 |
per_device_batch_size=4
|
75 |
seed=0
|
76 |
-
num_train_epochs=
|
77 |
-
learning_rate=
|
78 |
|
79 |
|
80 |
total_batch_size = per_device_batch_size * jax.local_device_count()
|
|
|
73 |
|
74 |
per_device_batch_size=4
|
75 |
seed=0
|
76 |
+
num_train_epochs=3
|
77 |
+
learning_rate=4e-5
|
78 |
|
79 |
|
80 |
total_batch_size = per_device_batch_size * jax.local_device_count()
|