mattricesound committed
Commit e8a69d9 · Parent(s): 61b9249

Log FAD only during test. Use rendered files during test.

Files changed (4)
  1. README.md +6 -5
  2. cfg/config.yaml +1 -1
  3. remfx/models.py +3 -0
  4. scripts/test.py +55 -0
README.md CHANGED
@@ -35,10 +35,11 @@ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' train
 - `reverb`
 - `all` (choose random effect to apply to each file)
 
+### Testing
+Experiment dictates data, ckpt dictates model
+`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
+
 ## Misc.
 By default, files are rendered to `input_dir / processed / train/val/test`.
-To skip rendering files (use previously rendered), add `render_files=False` to the command-line
-To change the rendered location, add `render_root={path/to/dir}` to the command-line
-Test
-Experiment dictates data, ckpt dictates model
-`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
+To skip rendering files (use previously rendered), add `render_files=False` to the command-line (added to test by default).
+To change the rendered location, add `render_root={path/to/dir}` to the command-line (use this for train and test)
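
As a usage sketch combining the two options above (the `{path/to/dir}` placeholder is the README's own, not a real path), a test run that reuses files previously rendered to a custom location would look like:

`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt render_root={path/to/dir}`

Since the test script already sets `render_files=False` by default, only `render_root` needs to be passed when the rendered data is not in the default location.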
 
 
 
cfg/config.yaml CHANGED
@@ -19,7 +19,7 @@ callbacks:
     save_last: True # additionaly always save model from last epoch
     mode: "min" # can be "max" or "min"
     verbose: False
-    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}-${exp}
     filename: '{epoch:02d}-{valid_loss:.3f}'
 
 datamodule:
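
To illustrate what the amended `dirpath` resolves to, here is a minimal OmegaConf sketch. The `logs_dir` and `exp` values are hypothetical stand-ins for keys defined elsewhere in the full config, and the snippet registers its own `now` resolver to mimic the one Hydra provides automatically inside a run.

```python
from datetime import datetime

from omegaconf import OmegaConf

# Stand-in for Hydra's built-in "now" resolver, which formats the current
# time with strftime; inside a real Hydra run it is registered for you.
OmegaConf.register_new_resolver("now", lambda fmt: datetime.now().strftime(fmt))

# Hypothetical top-level values; the real config defines its own logs_dir/exp.
cfg = OmegaConf.create(
    {
        "logs_dir": "logs",
        "exp": "umx_distortion",
        "dirpath": "${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}-${exp}",
    }
)

# Resolves to something like "logs/ckpts/2023-02-20-14-30-00-umx_distortion",
# so checkpoint directories from different experiments are no longer
# distinguished only by their timestamp.
print(cfg.dirpath)
```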
remfx/models.py CHANGED
@@ -79,6 +79,9 @@ class RemFXModel(pl.LightningModule):
             negate = -1
         else:
             negate = 1
+        # Only Log FAD on test set
+        if metric == "FAD" and mode != "test":
+            continue
         self.log(
             f"{mode}_{metric}",
             negate * self.metrics[metric](output, y),
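
To see the effect of the new guard in isolation, here is a small self-contained sketch of the kind of metric-logging loop it sits in. The metric names, the placeholder metric callables, and the print-based `log` are illustrative assumptions; only the negation rule, the FAD guard, and the `self.log(f"{mode}_{metric}", ...)` call follow the diff.

```python
# Standalone sketch of the logging loop around the new guard (illustrative;
# in the repo this lives in RemFXModel, a pl.LightningModule, and self.log
# is Lightning's logging call).
class MetricLoggerSketch:
    def __init__(self):
        # Placeholder metric callables; the real model applies audio metrics
        # such as FAD to (output, y) tensors.
        self.metrics = {
            "l1": lambda output, y: abs(output - y),
            "SISDR": lambda output, y: -abs(output - y),  # stand-in metric
            "FAD": lambda output, y: 0.5,  # expensive to compute in practice
        }

    def log(self, name, value):
        print(f"{name}: {value}")

    def log_metrics(self, output, y, mode):
        for metric in self.metrics:
            # Mirrors the diff's `negate` variable: certain metrics are
            # sign-flipped before logging (the condition here is assumed).
            negate = -1 if metric == "SISDR" else 1
            # New in this commit: FAD is only computed and logged during test.
            if metric == "FAD" and mode != "test":
                continue
            self.log(f"{mode}_{metric}", negate * self.metrics[metric](output, y))


sketch = MetricLoggerSketch()
sketch.log_metrics(output=0.9, y=1.0, mode="valid")  # FAD skipped
sketch.log_metrics(output=0.9, y=1.0, mode="test")   # FAD logged
```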
scripts/test.py ADDED
@@ -0,0 +1,55 @@
+import pytorch_lightning as pl
+import hydra
+from omegaconf import DictConfig
+import remfx.utils as utils
+from pytorch_lightning.utilities.model_summary import ModelSummary
+from remfx.models import RemFXModel
+import torch
+
+log = utils.get_logger(__name__)
+
+
+@hydra.main(version_base=None, config_path="../cfg", config_name="config.yaml")
+def main(cfg: DictConfig):
+    # Apply seed for reproducibility
+    if cfg.seed:
+        pl.seed_everything(cfg.seed)
+    cfg.render_files = False
+    log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
+    datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
+    log.info(f"Instantiating model <{cfg.model._target_}>.")
+    model = hydra.utils.instantiate(cfg.model, _convert_="partial")
+    state_dict = torch.load(cfg.ckpt_path, map_location=torch.device("cpu"))[
+        "state_dict"
+    ]
+    model.load_state_dict(state_dict)
+
+    # Init all callbacks
+    callbacks = []
+    if "callbacks" in cfg:
+        for _, cb_conf in cfg["callbacks"].items():
+            if "_target_" in cb_conf:
+                log.info(f"Instantiating callback <{cb_conf._target_}>.")
+                callbacks.append(hydra.utils.instantiate(cb_conf, _convert_="partial"))
+
+    logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
+    log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
+    trainer = hydra.utils.instantiate(
+        cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
+    )
+    log.info("Logging hyperparameters!")
+    utils.log_hyperparameters(
+        config=cfg,
+        model=model,
+        datamodule=datamodule,
+        trainer=trainer,
+        callbacks=callbacks,
+        logger=logger,
+    )
+    summary = ModelSummary(model)
+    print(summary)
+    trainer.test(model=model, datamodule=datamodule)
+
+
+if __name__ == "__main__":
+    main()
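
A brief usage note on the new entry point: the model architecture is still instantiated from the Hydra config (`hydra.utils.instantiate(cfg.model, ...)`) and only the weights come from `torch.load(cfg.ckpt_path, ...)["state_dict"]`, so the checkpoint passed via `+ckpt_path` has to match the model selected by the experiment config, e.g. (repeating the README example):

`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`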