# Listing metadata from the source viewer: file size 3,304 bytes, revision 59b7eeb.
from multiprocessing import Pool, cpu_count
import os
import torchaudio
import torch
from tqdm import tqdm
def process_file(args):
    """
    Validate one audio file and report its length.

    Args:
        args (tuple): A ``(file_path, root_dir)`` pair. Only the file path is
            used here; the root directory is unpacked and ignored.

    Returns:
        str or None: One tab-separated, newline-terminated record holding the
        ABSOLUTE path and ``waveform.shape[1]`` (the sample count), or None
        when the file has no samples, contains a NaN, or raises while being
        loaded or inspected.
    """
    path, _root = args
    try:
        resolved = os.path.abspath(path)
        audio, _rate = torchaudio.load(path)

        total = audio.numel()
        if total == 0:
            return None

        # Scan fixed-size slices of the flattened signal so the NaN mask is
        # only ever built for one slice at a time.
        samples = audio.reshape(-1)
        chunk = 10000
        offset = 0
        while offset < total:
            # Slicing clamps at the end of the tensor, so the last slice may
            # simply be shorter.
            if torch.isnan(samples[offset:offset + chunk]).any():
                print(f"NaN found in: {resolved}")
                return None
            offset += chunk

        return f"{resolved}\t{audio.shape[1]}\n"
    except Exception as err:
        # Best effort: a file that cannot be processed is reported and skipped.
        print(f"Error processing {path}: {err}")
        return None
def list_audio_files(root_dir, output_file, exclude_dirs=None):
    """
    List audio files under a directory, process them in parallel to get their
    lengths, and write the results to a file with ABSOLUTE paths.

    Files are matched case-insensitively on the ``.wav``, ``.flac`` and
    ``.mp3`` extensions, sorted by path, and handed to ``process_file`` in a
    worker pool. Every non-empty result is written to ``output_file`` in that
    sorted order. The output file is always (re)written, even when no audio
    files were found.

    Args:
        root_dir (str): The root directory to search for audio files.
        output_file (str): The path to the output file.
        exclude_dirs (list, optional): Directories to skip entirely. Blank
            entries are ignored. Defaults to None (nothing excluded).
    """
    # Blank entries are dropped on purpose: os.path.abspath('') resolves to the
    # current working directory, which would otherwise be excluded by accident.
    # A set keeps the per-directory membership test cheap.
    excluded = {os.path.abspath(d) for d in (exclude_dirs or ()) if d}

    audio_extensions = ('.wav', '.flac', '.mp3')
    audio_files = []
    print("Finding audio files...")
    for root, dirs, files in os.walk(root_dir, topdown=True):
        # Prune in place so os.walk never descends into excluded directories.
        dirs[:] = [
            d for d in dirs
            if os.path.abspath(os.path.join(root, d)) not in excluded
        ]
        audio_files.extend(
            (os.path.join(root, filename), root_dir)
            for filename in files
            if filename.lower().endswith(audio_extensions)
        )

    # Sort by path so the output order does not depend on directory walk order.
    audio_files.sort(key=lambda item: item[0])
    print(f"Found {len(audio_files)} audio files to process.")

    results = []
    if audio_files:
        # Use half of the available cores, but always at least one worker.
        num_processes = max(1, cpu_count() // 2)
        print(f"Starting processing with {num_processes} processes...")
        with Pool(processes=num_processes) as pool:
            # imap yields results in input order, so the sorted order is kept.
            results = list(tqdm(pool.imap(process_file, audio_files),
                                total=len(audio_files),
                                desc="Processing audio files"))

    print(f"Writing results to {output_file}...")
    with open(output_file, 'w', encoding='utf-8') as file:
        # Rejected files come back as None and are skipped.
        file.writelines(result for result in results if result)
    print("Processing complete.")
# Script configuration: where to look for audio and where to write the listing.
root_directory = '/home/ubuntu/respair/test_wav'
output_tsv = '/home/ubuntu/X-Codec-2.0/audio_high_quality_TEST.txt'
# Nothing is excluded. An empty list is used rather than [''] because an empty
# string resolves to the current working directory when made absolute.
exclude_folders = []

# Guard the entry point: worker processes started with the "spawn" method
# re-import this module, and must not start the whole job again.
if __name__ == "__main__":
    list_audio_files(root_directory, output_tsv, exclude_dirs=exclude_folders)