import glob
import os
from tqdm import tqdm
from dataclasses import dataclass
import torch
from torch import Tensor
from torch.multiprocessing import Pool, set_start_method
import torchaudio
from config import MelConfig, TrainConfig

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

@dataclass
class DataConfig:
    audio_dir: str = './audios'          # path to source audios
    output_dir: str = './vocos_datasets' # path to save processed audios
    filelist_path: str = './filelists/filelist.txt' # path to save filelist
data_config = DataConfig()
train_config = TrainConfig()
mel_config = MelConfig()
audio_dir = data_config.audio_dir
output_dir = data_config.output_dir
filelist_path = data_config.filelist_path
segment_size = train_config.segment_size

output_audio_dir = os.path.join(output_dir, 'audios')

# Ensure output directories exist
os.makedirs(output_audio_dir, exist_ok=True)
os.makedirs(os.path.dirname(filelist_path), exist_ok=True)

def load_and_resample_audio(audio_path, target_sr, segment_size, device='cpu') -> Tensor:
    try:
        y, sr = torchaudio.load(audio_path)
    except Exception as e:
        print(str(e))
        return None
    
    y = y.to(device)
    
    # Convert to mono by keeping the first channel
    if y.size(0) > 1:
        y = y[0, :].unsqueeze(0) # shape: [channels, time] -> [1, time]
        
    # resample audio to target sample_rate
    if sr != target_sr:
        y = torchaudio.functional.resample(y, sr, target_sr)
        
    if y.size(-1) < segment_size:
        y = torch.nn.functional.pad(y, (0, segment_size - y.size(-1)), "constant", 0)
        
    return y
    
def find_audio_files(directory) -> list:
    extensions = ['wav', 'mp3', 'flac']
    files_found = []
    for extension in extensions:
        pattern = os.path.join(directory, '**', f'*.{extension}')
        files_found.extend(glob.glob(pattern, recursive=True))
    return files_found

@torch.inference_mode()
def process_audio(audio_path):
    audio = load_and_resample_audio(audio_path, mel_config.sample_rate, segment_size, device=device) # shape: [1, time]
    if audio is not None:
        
        # get output path
        audio_name, _ = os.path.splitext(os.path.basename(audio_path))
        output_audio_path = os.path.join(output_audio_dir, audio_name + '.wav')
        
        # save resampled audio and mel features
        torchaudio.save(output_audio_path, audio.cpu(), mel_config.sample_rate)
        
        return output_audio_path

def main():
    set_start_method('spawn') # CUDA cannot be used in forked workers; 'spawn' is required
    audio_files = find_audio_files(audio_dir)
    results = []
    
    with Pool(processes=8) as pool:
        for result in tqdm(pool.imap(process_audio, audio_files), total=len(audio_files)):
            if result is not None:
                results.append(f'{result}\n') 
            
    # save filelist
    with open(filelist_path, 'w', encoding='utf-8') as f:
        f.writelines(results)
    print(f"filelist file has been saved to {filelist_path}")

# Keep each worker single-threaded: with 8 processes, letting torch spawn its own
# thread pools would oversubscribe the CPU. Placed at module level (not under the
# __main__ guard) so spawned workers run it on import as well.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
    
if __name__ == '__main__':
    main()