winglian commited on
Commit
6fa40bf
1 Parent(s): 3aad5f3

black formatting

Browse files
Files changed (2) hide show
  1. scripts/finetune.py +3 -1
  2. tests/test_validation.py +3 -1
scripts/finetune.py CHANGED
@@ -152,7 +152,9 @@ def train(
152
  validate_config(cfg)
153
 
154
  # setup some derived config / hyperparams
155
- cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (cfg.batch_size // cfg.micro_batch_size)
 
 
156
  cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
157
  cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
158
  choose_device(cfg)
 
152
  validate_config(cfg)
153
 
154
  # setup some derived config / hyperparams
155
+ cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps or (
156
+ cfg.batch_size // cfg.micro_batch_size
157
+ )
158
  cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
159
  cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
160
  choose_device(cfg)
tests/test_validation.py CHANGED
@@ -126,7 +126,9 @@ class ValidationTest(unittest.TestCase):
126
  }
127
  )
128
 
129
- with pytest.raises(ValueError, match=r".*gradient_accumulation_steps or batch_size.*"):
 
 
130
  validate_config(cfg)
131
 
132
  cfg = DictDefault(
 
126
  }
127
  )
128
 
129
+ with pytest.raises(
130
+ ValueError, match=r".*gradient_accumulation_steps or batch_size.*"
131
+ ):
132
  validate_config(cfg)
133
 
134
  cfg = DictDefault(