Spaces:
Runtime error
Runtime error
mattricesound
commited on
Commit
·
5570d2c
1
Parent(s):
1ff07dc
Fix LRScheduler and ckpt_path
Browse files- README.md +1 -1
- cfg/config.yaml +1 -0
- remfx/models.py +3 -2
- scripts/train.py +1 -1
README.md
CHANGED
@@ -22,7 +22,7 @@ Models and effects detailed below.
|
|
22 |
|
23 |
To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
|
24 |
|
25 |
-
Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices
|
26 |
|
27 |
### Current Models
|
28 |
- `umx`
|
|
|
22 |
|
23 |
To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
|
24 |
|
25 |
+
Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=1`
|
26 |
|
27 |
### Current Models
|
28 |
- `umx`
|
cfg/config.yaml
CHANGED
@@ -77,6 +77,7 @@ trainer:
|
|
77 |
enable_model_summary: False
|
78 |
log_every_n_steps: 1 # Logs metrics every N batches
|
79 |
accumulate_grad_batches: 1
|
|
|
80 |
accelerator: null
|
81 |
devices: 1
|
82 |
gradient_clip_val: 10.0
|
|
|
77 |
enable_model_summary: False
|
78 |
log_every_n_steps: 1 # Logs metrics every N batches
|
79 |
accumulate_grad_batches: 1
|
80 |
+
deterministic: True
|
81 |
accelerator: null
|
82 |
devices: 1
|
83 |
gradient_clip_val: 10.0
|
remfx/models.py
CHANGED
@@ -63,8 +63,9 @@ class RemFXModel(pl.LightningModule):
|
|
63 |
optimizer,
|
64 |
optimizer_idx,
|
65 |
optimizer_closure,
|
66 |
-
on_tpu
|
67 |
-
|
|
|
68 |
):
|
69 |
# update params
|
70 |
optimizer.step(closure=optimizer_closure)
|
|
|
63 |
optimizer,
|
64 |
optimizer_idx,
|
65 |
optimizer_closure,
|
66 |
+
on_tpu,
|
67 |
+
using_native_amp,
|
68 |
+
using_lbfgs,
|
69 |
):
|
70 |
# update params
|
71 |
optimizer.step(closure=optimizer_closure)
|
scripts/train.py
CHANGED
@@ -42,7 +42,7 @@ def main(cfg: DictConfig):
|
|
42 |
summary = ModelSummary(model)
|
43 |
print(summary)
|
44 |
trainer.fit(model=model, datamodule=datamodule)
|
45 |
-
trainer.test(model=model, datamodule=datamodule)
|
46 |
|
47 |
|
48 |
if __name__ == "__main__":
|
|
|
42 |
summary = ModelSummary(model)
|
43 |
print(summary)
|
44 |
trainer.fit(model=model, datamodule=datamodule)
|
45 |
+
trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
|
46 |
|
47 |
|
48 |
if __name__ == "__main__":
|