Spaces:
Starting
on
A100
Starting
on
A100
nroggendorff
committed on
Commit
•
837ed4a
1
Parent(s):
67fdfd0
Update train.py
Browse files
train.py
CHANGED
@@ -18,7 +18,8 @@ OUTPUT_REPO = "smallama"
|
|
18 |
FP16 = True
|
19 |
WARMUP_STEPS = 500
|
20 |
DECAY = 0.01
|
21 |
-
|
|
|
22 |
PUSH_TO_HUB = True
|
23 |
|
24 |
def load_data():
|
@@ -101,9 +102,9 @@ def train_model(model, tokenizer, dataset, push):
|
|
101 |
optim="adamw_torch",
|
102 |
warmup_steps=WARMUP_STEPS,
|
103 |
weight_decay=DECAY,
|
104 |
-
gradient_accumulation_steps=
|
105 |
fp16=FP16,
|
106 |
-
|
107 |
)
|
108 |
|
109 |
optimizer = AdamW(model.parameters(), lr=args.learning_rate)
|
|
|
18 |
FP16 = True
|
19 |
WARMUP_STEPS = 500
|
20 |
DECAY = 0.01
|
21 |
+
GRADIENT_ACCUMULATION_STEPS = 4
|
22 |
+
CLIPPING = 1.0
|
23 |
PUSH_TO_HUB = True
|
24 |
|
25 |
def load_data():
|
|
|
102 |
optim="adamw_torch",
|
103 |
warmup_steps=WARMUP_STEPS,
|
104 |
weight_decay=DECAY,
|
105 |
+
gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
|
106 |
fp16=FP16,
|
107 |
+
max_grad_norm=CLIPPING
|
108 |
)
|
109 |
|
110 |
optimizer = AdamW(model.parameters(), lr=args.learning_rate)
|