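r"""Separate every audio file in a directory with a trained model.

The script loads a checkpoint described by a YAML config, runs the Separator
over each file in --audios_dir, and writes each separated result to
--output_dir as an mp3 file.
"""
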
import argparse
import os
import pathlib
import time

import librosa
import numpy as np
import soundfile
import torch

from bytesep.inference import Separator
from bytesep.models.lightning_modules import get_model_class
from bytesep.utils import read_yaml


def inference(args) -> None:
    r"""Separate all audio files in a directory.

    Args:
        args: argparse.Namespace with the following attributes:
            config_yaml: str, the config file of the trained model
            checkpoint_path: str, the path of the checkpoint to be loaded
            audios_dir: str, the directory of audio files to be separated
            output_dir: str, the directory to write out separated audio files
            scale_volume: bool, if True, scale each separated audio to a
                maximum absolute value of 1
            cuda: bool, if True, run separation on GPU when available
    """

    # Arguments & parameters
    config_yaml = args.config_yaml
    checkpoint_path = args.checkpoint_path
    audios_dir = args.audios_dir
    output_dir = args.output_dir
    scale_volume = args.scale_volume
    device = (
        torch.device('cuda')
        if args.cuda and torch.cuda.is_available()
        else torch.device('cpu')
    )

    configs = read_yaml(config_yaml)
    sample_rate = configs['train']['sample_rate']
    input_channels = configs['train']['channels']
    target_source_types = configs['train']['target_source_types']
    target_sources_num = len(target_source_types)
    model_type = configs['train']['model_type']
    mono = input_channels == 1

    segment_samples = int(30 * sample_rate)
    batch_size = 1

    # Models that contain inplace_abn.abn.InPlaceABNSync need torch.distributed
    # to be initialized, even for single-process inference.
    models_contains_inplaceabn = True

    if models_contains_inplaceabn:
        import torch.distributed as dist

        dist.init_process_group(
            'gloo', init_method='file:///tmp/somefile', rank=0, world_size=1
        )

    print("Using {} for separating ..".format(device))

    # paths
    os.makedirs(output_dir, exist_ok=True)

    # Get model class.
    Model = get_model_class(model_type)

    # Create model.
    model = Model(input_channels=input_channels, target_sources_num=target_sources_num)

    # Load checkpoint.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint["model"])

    # Move model to device.
    model.to(device)

    # Create separator.
    separator = Separator(
        model=model,
        segment_samples=segment_samples,
        batch_size=batch_size,
        device=device,
    )

    audio_names = sorted(os.listdir(audios_dir))

    for audio_name in audio_names:
        audio_path = os.path.join(audios_dir, audio_name)

        # Load audio.
        audio, _ = librosa.load(audio_path, sr=sample_rate, mono=mono)

        if audio.ndim == 1:
            audio = audio[None, :]

        input_dict = {'waveform': audio}

        # Separate
        separate_time = time.time()

        sep_wav = separator.separate(input_dict)
        # (channels_num, audio_samples)

        print('Separation time: {:.3f} s'.format(time.time() - separate_time))

        # Optionally scale the separated audio to a maximum absolute value of 1.
        if scale_volume:
            sep_wav /= np.max(np.abs(sep_wav))

        # Write to a temporary wav file, then convert it to mp3 with ffmpeg.
        soundfile.write(file='_zz.wav', data=sep_wav.T, samplerate=sample_rate)

        output_path = os.path.join(
            output_dir, '{}.mp3'.format(pathlib.Path(audio_name).stem)
        )
        os.system('ffmpeg -y -loglevel panic -i _zz.wav "{}"'.format(output_path))
        print('Write out to {}'.format(output_path))


if __name__ == "__main__":

    parser = argparse.ArgumentParser(
        description="Separate all audio files in a directory with a trained model."
    )
    parser.add_argument(
        "--config_yaml",
        type=str,
        required=True,
        help="The config file of a model being trained.",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=True,
        help="The path of checkpoint to be loaded.",
    )
    parser.add_argument(
        "--audios_dir",
        type=str,
        required=True,
        help="The directory of audios to be separated.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="The directory to write out separated audios.",
    )
    parser.add_argument(
        '--scale_volume',
        action='store_true',
        default=False,
        help="set to True if separated audios are scaled to the maximum value of 1.",
    )
    parser.add_argument("--cuda", action='store_true', default=True)

    args = parser.parse_args()

    inference(args)
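
# Example invocation (the script name and all paths below are placeholders;
# substitute your own config, checkpoint, and directories):
#
#   python3 separate_dir.py \
#       --config_yaml="configs/vocals.yaml" \
#       --checkpoint_path="checkpoints/vocals/step=300000.pth" \
#       --audios_dir="resources/mixtures" \
#       --output_dir="separated_results" \
#       --scale_volume \
#       --cuda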