File size: 1,050 Bytes
e6fd727
557fb53
0030bc6
e6fd727
557fb53
e6fd727
557fb53
e6fd727
3b31903
 
557fb53
 
 
c914273
3b31903
 
0030bc6
 
 
c914273
e6fd727
c914273
3b31903
 
 
 
 
 
 
 
e6fd727
 
557fb53
 
 
3b31903
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from typing import Callable
import importlib
import yaml
from argparse import ArgumentParser
import os

# Final path component (the bare name, not the full path) of the directory
# holding this file; empty string if __file__ carries no directory part.
ROOT_DIR = os.path.split(os.path.dirname(__file__))[1]


def get_training_fn(id: str) -> Callable:
    """Resolve a dotted training-function id to the object it names.

    The id has the form ``"<module>.<function>"``; ``<module>`` may itself be
    dotted and is looked up under the ``models`` package. For example,
    ``"foo.bar.train"`` imports ``models.foo.bar`` and returns its ``train``
    attribute.

    Args:
        id: Dotted path of the training function, relative to ``models``.

    Returns:
        The attribute named by the last dotted component of ``id``.

    Raises:
        ValueError: If ``id`` lacks a non-empty module part and a non-empty
            function name separated by a dot.
        ModuleNotFoundError: If ``models.<module>`` cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_name, sep, fn_name = id.rpartition(".")
    if not sep or not module_name or not fn_name:
        raise ValueError(
            f"training function id {id!r} must look like '<module>.<function>'"
        )
    # "models.<module>" is an absolute module name, so no ``package`` anchor is
    # passed: import_module only consults it for names starting with ".".
    module = importlib.import_module("models." + module_name)
    try:
        return getattr(module, fn_name)
    except AttributeError as err:
        raise AttributeError(
            f"module {module.__name__!r} has no training function {fn_name!r}"
        ) from err


def get_config(filepath: str) -> dict:
    """Load a YAML training configuration file.

    Args:
        filepath: Path to the YAML file.

    Returns:
        The parsed top-level mapping. An empty document yields ``{}``.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the document's top level is not a mapping.
    """
    # Binary mode lets PyYAML detect the encoding itself (UTF-8, or UTF-16
    # via BOM) instead of relying on the platform's default text encoding.
    with open(filepath, "rb") as f:
        # safe_load builds only plain Python objects, never arbitrary classes.
        config = yaml.safe_load(f)
    if config is None:  # empty file / empty document
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"{filepath}: expected a YAML mapping at the top level, "
            f"got {type(config).__name__}"
        )
    return config


def main() -> None:
    """Parse CLI arguments, load the config, and run its training function.

    The config must define ``training_fn``, a dotted id resolved by
    ``get_training_fn``. The resolved callable is then called with the whole
    config dict.
    """
    parser = ArgumentParser(
        description="Trains models on the dance dataset and saves weights."
    )
    parser.add_argument(
        "--config",
        help="Path to the yaml file that defines the training configuration.",
        # Relative path: resolved against the current working directory.
        default="models/config/train_local.yaml",
    )
    args = parser.parse_args()
    config = get_config(args.config)
    # Report a missing or malformed key as a CLI error (usage + message,
    # exit status 2) rather than a bare KeyError/TypeError traceback.
    training_fn_path = (
        config.get("training_fn") if isinstance(config, dict) else None
    )
    if not isinstance(training_fn_path, str):
        parser.error(
            f"config {args.config!r} must define 'training_fn' as a string"
        )
    print(f"Config: {args.config}\nTrainer Id: {training_fn_path}")
    train = get_training_fn(training_fn_path)
    train(config)


if __name__ == "__main__":
    main()