black formatting
Browse files- scripts/finetune.py +3 -1
- tests/test_validation.py +3 -1
scripts/finetune.py
CHANGED
@@ -152,7 +152,9 @@ def train(
|
|
152 |
validate_config(cfg)
|
153 |
|
154 |
# setup some derived config / hyperparams
|
155 |
-
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
|
|
|
|
156 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
157 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
158 |
choose_device(cfg)
|
|
|
152 |
validate_config(cfg)
|
153 |
|
154 |
# setup some derived config / hyperparams
|
155 |
+
cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
|
156 |
+
cfg.batch_size // cfg.micro_batch_size
|
157 |
+
)
|
158 |
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
159 |
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
160 |
choose_device(cfg)
|
tests/test_validation.py
CHANGED
@@ -126,7 +126,9 @@ class ValidationTest(unittest.TestCase):
|
|
126 |
}
|
127 |
)
|
128 |
|
129 |
-
with pytest.raises(
|
|
|
|
|
130 |
validate_config(cfg)
|
131 |
|
132 |
cfg = DictDefault(
|
|
|
126 |
}
|
127 |
)
|
128 |
|
129 |
+
with pytest.raises(
|
130 |
+
ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
|
131 |
+
):
|
132 |
validate_config(cfg)
|
133 |
|
134 |
cfg = DictDefault(
|