"""Resampling script. |
|
""" |
|
import argparse
from pathlib import Path
import shutil
import typing as tp

import submitit
import tqdm

from audiocraft.data.audio import audio_read, audio_write
from audiocraft.data.audio_dataset import load_audio_meta, find_audio_files
from audiocraft.data.audio_utils import convert_audio
from audiocraft.environment import AudioCraftEnvironment


def read_txt_files(path: tp.Union[str, Path]): |
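    """Read one file path per line from a .txt file, dropping metadata files."""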
|
    with open(path) as f:
        lines = [line.rstrip() for line in f]
        print(f"Read {len(lines)} lines from the .txt file")
        lines = [line for line in lines if Path(line).suffix not in ['.json', '.txt', '.csv']]
        print(f"Kept {len(lines)} audio paths after filtering")
        return lines


def read_egs_files(path: tp.Union[str, Path]): |
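    """Return the audio paths from an egs manifest: a data.jsonl[.gz] file,
    or a directory containing one."""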
|
    path = Path(path)
    if path.is_dir():
        if (path / 'data.jsonl').exists():
            path = path / 'data.jsonl'
        elif (path / 'data.jsonl.gz').exists():
            path = path / 'data.jsonl.gz'
        else:
            raise ValueError("Don't know where to read metadata from in the dir. "
                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
    meta = load_audio_meta(path)
    return [m.path for m in meta]


def process_dataset(args, n_shards: int, node_index: int, task_index: tp.Optional[int] = None): |
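    """Resample the shard of the dataset assigned to this task.

    Each task owns the files whose index satisfies idx % n_shards == shard_index,
    where shard_index is derived from the node index and the submitit task rank.
    """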
|
    if task_index is None:
        env = submitit.JobEnvironment()
        task_index = env.global_rank
    shard_index = node_index * args.tasks_per_node + task_index

    if args.files_path is None:
        lines = [m.path for m in find_audio_files(args.root_path, resolve=False, progress=True, workers=8)]
    else:
        files_path = Path(args.files_path)
        if files_path.suffix == '.txt':
            print(f"Reading file list from .txt file: {args.files_path}")
            lines = read_txt_files(args.files_path)
        else:
            print(f"Reading file list from egs: {args.files_path}")
            lines = read_egs_files(args.files_path)

    total_files = len(lines)
    print(
        f"Total of {total_files} files split across {n_shards} shards. "
        f"Current shard index = {shard_index} -> ~{total_files // n_shards} files to process"
    )
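    # Sharding is round-robin by file index: file `idx` is handled by the task
    # whose shard_index equals idx % n_shards.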
|
    for idx, line in tqdm.tqdm(enumerate(lines), total=total_files):
        if idx % n_shards != shard_index:
            continue

        path = str(AudioCraftEnvironment.apply_dataset_mappers(line))
        root_path = str(args.root_path)
        if not root_path.endswith('/'):
            root_path += '/'
        assert path.startswith(root_path), \
            f"Mismatch between path and provided root: {path} VS {root_path}"

        try:
            metadata_path = Path(path).with_suffix('.json')
            out_path = args.out_path / path[len(root_path):]
            out_metadata_path = out_path.with_suffix('.json')
            out_done_token = out_path.with_suffix('.done')

            # The .done token makes the script resumable after preemption.
            if out_done_token.exists():
                continue

            print(idx, out_path, path)
            mix, sr = audio_read(path)
            mix_channels = args.channels if args.channels is not None and args.channels > 0 else mix.size(0)

            out_channels = mix_channels
            if out_channels > 2:
                print(f"Mix has more than two channels: {out_channels}, enforcing 2 channels")
                out_channels = 2
            out_sr = args.sample_rate if args.sample_rate is not None else sr
            out_wav = convert_audio(mix, sr, out_sr, out_channels)
            audio_write(out_path.with_suffix(''), out_wav, sample_rate=out_sr,
                        format=args.format, normalize=False, strategy='clip')
            if metadata_path.exists():
                shutil.copy(metadata_path, out_metadata_path)
            else:
                print(f"No metadata found at {metadata_path}")
            out_done_token.touch()
        except Exception as e:
            print(f"Error processing file line: {line}, {e}")


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Resample dataset with SLURM.")
    parser.add_argument(
        "--log_root",
        type=Path,
        default=Path.home() / 'tmp' / 'resample_logs',
    )
    parser.add_argument(
        "--files_path",
        type=Path,
        help="List of files to process, either .txt (one file per line) or a jsonl[.gz].",
    )
    parser.add_argument(
        "--root_path",
        type=Path,
        required=True,
        help="When rewriting paths, this will be the prefix to remove.",
    )
    parser.add_argument(
        "--out_path",
        type=Path,
        required=True,
        help="When rewriting paths, `root_path` will be replaced by this.",
    )
    parser.add_argument("--xp_name", type=str, default="shutterstock")
    parser.add_argument(
        "--nodes",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--tasks_per_node",
        type=int,
        default=20,
    )
    parser.add_argument(
        "--cpus_per_task",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--memory_gb",
        type=int,
        help="Memory in GB.",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="wav",
    )
    parser.add_argument(
        "--sample_rate",
        type=int,
        default=32000,
    )
    parser.add_argument(
        "--channels",
        type=int,
    )
    parser.add_argument(
        "--partition",
        default='learnfair',
    )
    parser.add_argument("--qos")
    parser.add_argument("--account")
    parser.add_argument("--timeout", type=int, default=4320)
    parser.add_argument('--debug', action='store_true', help='debug mode (local run)')
    args = parser.parse_args()
|
    n_shards = args.tasks_per_node * args.nodes
    if args.files_path is None:
        print("Warning: --files_path not provided; scanning --root_path directly "
              "is not recommended when processing more than 10k files.")
    if args.debug:
        print("Debugging mode")
        process_dataset(args, n_shards=n_shards, node_index=0, task_index=0)
    else:
        log_folder = Path(args.log_root) / args.xp_name / '%j'
        print(f"Logging to: {log_folder}")
        log_folder.parent.mkdir(parents=True, exist_ok=True)
        executor = submitit.AutoExecutor(folder=str(log_folder))
        if args.qos:
            executor.update_parameters(slurm_partition=args.partition, slurm_qos=args.qos, slurm_account=args.account)
        else:
            executor.update_parameters(slurm_partition=args.partition)
        executor.update_parameters(
            slurm_job_name=args.xp_name, timeout_min=args.timeout,
            cpus_per_task=args.cpus_per_task, tasks_per_node=args.tasks_per_node, nodes=1)
        if args.memory_gb:
            executor.update_parameters(mem=f'{args.memory_gb}GB')
        jobs = []
        with executor.batch():
            for node_index in range(args.nodes):
                job = executor.submit(process_dataset, args, n_shards=n_shards, node_index=node_index)
                jobs.append(job)
        for job in jobs:
            print(f"Waiting on job {job.job_id}")
            job.results()