mattricesound commited on
Commit
5570d2c
·
1 Parent(s): 1ff07dc

Fix LRScheduler and ckpt_path

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. cfg/config.yaml +1 -0
  3. remfx/models.py +3 -2
  4. scripts/train.py +1 -1
README.md CHANGED
@@ -22,7 +22,7 @@ Models and effects detailed below.
22
 
23
  To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
24
 
25
- Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=-1`
26
 
27
  ### Current Models
28
  - `umx`
 
22
 
23
  To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
24
 
25
+ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=1`
26
 
27
  ### Current Models
28
  - `umx`
cfg/config.yaml CHANGED
@@ -77,6 +77,7 @@ trainer:
77
  enable_model_summary: False
78
  log_every_n_steps: 1 # Logs metrics every N batches
79
  accumulate_grad_batches: 1
 
80
  accelerator: null
81
  devices: 1
82
  gradient_clip_val: 10.0
 
77
  enable_model_summary: False
78
  log_every_n_steps: 1 # Logs metrics every N batches
79
  accumulate_grad_batches: 1
80
+ deterministic: True
81
  accelerator: null
82
  devices: 1
83
  gradient_clip_val: 10.0
remfx/models.py CHANGED
@@ -63,8 +63,9 @@ class RemFXModel(pl.LightningModule):
63
  optimizer,
64
  optimizer_idx,
65
  optimizer_closure,
66
- on_tpu=False,
67
- using_lbfgs=False,
 
68
  ):
69
  # update params
70
  optimizer.step(closure=optimizer_closure)
 
63
  optimizer,
64
  optimizer_idx,
65
  optimizer_closure,
66
+ on_tpu,
67
+ using_native_amp,
68
+ using_lbfgs,
69
  ):
70
  # update params
71
  optimizer.step(closure=optimizer_closure)
scripts/train.py CHANGED
@@ -42,7 +42,7 @@ def main(cfg: DictConfig):
42
  summary = ModelSummary(model)
43
  print(summary)
44
  trainer.fit(model=model, datamodule=datamodule)
45
- trainer.test(model=model, datamodule=datamodule)
46
 
47
 
48
  if __name__ == "__main__":
 
42
  summary = ModelSummary(model)
43
  print(summary)
44
  trainer.fit(model=model, datamodule=datamodule)
45
+ trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
46
 
47
 
48
  if __name__ == "__main__":