File size: 2,186 Bytes
5019931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from typing import List

import pytorch_lightning as pl
import torch.nn as nn


def get_callbacks(
    task_name: str,
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get callbacks of a task and config yaml file.

    Dispatches on ``task_name`` to a task-specific factory in
    ``bytesep.callbacks``. The factory module is imported lazily, so only the
    module for the selected task is loaded.

    Args:
        task_name: str, one of 'musdb18', 'voicebank-demand', 'vctk-musdb18',
            'violin-piano', 'piano-symphony'
        config_yaml: str
        workspace: str, containing useful files such as audios for evaluation
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path for saving statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Returns:
        callbacks: List[pl.Callback]

    Raises:
        NotImplementedError: if ``task_name`` is not a supported task.
    """
    # Every task-specific factory below takes exactly the same keyword
    # arguments; build them once so the branches cannot drift apart.
    callback_kwargs = dict(
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )

    if task_name == 'musdb18':

        from bytesep.callbacks.musdb18 import get_musdb18_callbacks

        return get_musdb18_callbacks(**callback_kwargs)

    elif task_name == 'voicebank-demand':

        from bytesep.callbacks.voicebank_demand import get_voicebank_demand_callbacks

        return get_voicebank_demand_callbacks(**callback_kwargs)

    elif task_name in ('vctk-musdb18', 'violin-piano', 'piano-symphony'):

        from bytesep.callbacks.instruments_callbacks import get_instruments_callbacks

        return get_instruments_callbacks(**callback_kwargs)

    else:
        # Same exception type as before, but name the offending value so a
        # misspelled task is easy to diagnose.
        raise NotImplementedError(
            f"Unsupported task_name for callbacks: {task_name!r}"
        )