|
import os
|
|
import subprocess as sp
|
|
import sys
|
|
import time
|
|
from datetime import timedelta
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from random import Random
|
|
|
|
import click
|
|
import numpy as np
|
|
import torch
|
|
import torchaudio
|
|
from hydra import compose, initialize
|
|
from hydra.utils import instantiate
|
|
from lightning import LightningModule
|
|
from loguru import logger
|
|
from omegaconf import OmegaConf
|
|
|
|
from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
|
|
|
|
|
|
# Register an "eval" resolver with OmegaConf, backed by Python's built-in ``eval``.
# NOTE(security): any config string routed through this resolver is executed as
# Python code, so only compose configs that come from a trusted source.
OmegaConf.register_new_resolver("eval", eval)
|
|
|
|
|
|
|
|
|
|
# Shard identity of this process, read from the SLURM environment variables
# (falls back to a single-process layout: rank 0 of 1).
RANK = int(os.getenv("SLURM_PROCID", "0"))
WORLD_SIZE = int(os.getenv("SLURM_NTASKS", "1"))

# Log line layout: timestamp | level | location | rank tag - message.
logger_format = " | ".join(
    (
        "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green>",
        "<level>{level: <8}</level>",
        "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan>",
        "{extra[rank]} - <level>{message}</level>",
    )
)

# Attach the rank tag to every record, then replace the existing handlers
# with a single stderr sink that uses the layout above.
_rank_tag = f"RANK: {RANK} / {WORLD_SIZE}"
logger.configure(extra={"rank": _rank_tag})
logger.remove()
logger.add(sys.stderr, format=logger_format)
|
|
|
|
|
|
@lru_cache(maxsize=1)
def get_model(
    config_name: str = "firefly_gan_vq",
    checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
    device: str | torch.device = "cuda",
):
    """Instantiate the model described by a Hydra config and load its weights.

    The call is memoised with a single-slot ``lru_cache``, so calling it again
    with the same arguments returns the already-loaded model.

    Args:
        config_name: Name of the Hydra config to compose.
        checkpoint_path: Checkpoint file passed to ``torch.load``.
        device: Used as ``map_location`` when loading and as the target of
            the final ``model.to``.

    Returns:
        The instantiated model, switched to eval mode and moved to ``device``.
    """
    with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
        config = compose(config_name=config_name)

    model = instantiate(config)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Some checkpoints nest the weights under a "state_dict" entry.
    weights = checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    # When the checkpoint mentions a generator, keep only the "generator."
    # entries and drop that marker from their names.
    if any("generator" in name for name in weights):
        generator_weights = {}
        for name, value in weights.items():
            if "generator." in name:
                generator_weights[name.replace("generator.", "")] = value
        weights = generator_weights

    # Non-strict load, as in the original.
    model.load_state_dict(weights, strict=False)
    model.eval()
    model.to(device)

    logger.info("Loaded model")
    return model
|
|
|
|
|
|
@torch.inference_mode()
def process_batch(files: list[Path], model) -> float:
    """Encode a batch of audio files and save each result next to its source.

    Every readable file is loaded, down-mixed to mono, resampled to
    ``model.spec_transform.sample_rate`` and zero-padded to the longest clip in
    the batch. The padded batch goes through ``model.encode``; the indices
    returned for each file are cut to that file's feature length and written
    to ``<file>.npy`` (the source path with its suffix replaced).

    Args:
        files: Audio files to process. Files that cannot be read are logged
            and skipped.
        model: Object exposing ``spec_transform.sample_rate``, ``device`` and
            ``encode(audios, audio_lengths)``.

    Returns:
        Total duration in seconds of the audio that was successfully loaded;
        ``0.0`` when nothing could be loaded.
    """
    sample_rate = model.spec_transform.sample_rate
    device = model.device
    backend = "sox" if sys.platform == "linux" else "soundfile"

    wavs = []
    loaded_files = []
    for file in files:
        try:
            wav, sr = torchaudio.load(str(file), backend=backend)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole batch.
            logger.error(f"Error reading {file}: {e}")
            continue

        # Average the channels when the clip has more than one.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)

        # Resample on the model's own device (rather than a hard-coded
        # ``.cuda()``) so the audio and ``audio_lengths`` always live together.
        wav = torchaudio.functional.resample(wav.to(device), sr, sample_rate)[0]

        wavs.append(wav)
        loaded_files.append(file)

    if not wavs:
        # Empty batch, or every file failed to load: ``torch.stack`` would
        # raise on an empty list, so there is nothing to encode.
        return 0.0

    lengths = [len(wav) for wav in wavs]
    max_length = max(lengths)
    total_time = sum(length / sample_rate for length in lengths)

    # Right-pad every clip with zeros to the longest one and add a channel axis.
    audios = torch.stack(
        [
            torch.nn.functional.pad(wav, (0, max_length - len(wav)), "constant")
            for wav in wavs
        ],
        dim=0,
    )[:, None]
    audio_lengths = torch.tensor(lengths, device=device, dtype=torch.long)

    indices, feature_lengths = model.encode(audios, audio_lengths)
    outputs = indices.cpu().numpy()

    for file, length, feature in zip(loaded_files, feature_lengths, outputs):
        # Drop the trailing positions that only exist because of padding.
        with open(file.with_suffix(".npy"), "wb") as f:
            np.save(f, feature[:, : int(length)])

    return total_time
|
|
|
|
|
|
@click.command()
@click.argument("folder")
@click.option("--num-workers", default=1)
@click.option("--config-name", default="firefly_gan_vq")
@click.option(
    "--checkpoint-path",
    default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
)
@click.option("--batch-size", default=64)
@click.option("--filelist", default=None, type=Path)
def main(
    folder: str,
    num_workers: int,
    config_name: str,
    checkpoint_path: str,
    batch_size: int,
    filelist: Path | None,
):
    """Encode the audio files under FOLDER (or listed in --filelist) to .npy.

    Files that already have a .npy sibling are skipped. With --num-workers > 1
    outside SLURM, the script re-launches itself once per worker and the
    workers split the file list between them.
    """
    # Launcher mode: more than one worker was requested but this process is not
    # already one of them, so re-run the same command line once per worker.
    if num_workers > 1 and WORLD_SIZE != num_workers:
        if WORLD_SIZE != 1:
            # Raised explicitly (not ``assert``) so the check survives ``python -O``.
            raise click.UsageError(
                "You should either use SLURM or this launcher, not both"
            )

        logger.info(f"Spawning {num_workers} workers")

        if torch.cuda.is_available():
            visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
            if visible_devices is None:
                visible_devices = list(range(torch.cuda.device_count()))
            else:
                visible_devices = visible_devices.split(",")
        else:
            # No CUDA: every worker gets an empty CUDA_VISIBLE_DEVICES.
            visible_devices = [""]

        processes = []
        for i in range(num_workers):
            env = os.environ.copy()
            # Devices are handed out round-robin; the SLURM-style variables are
            # what the child reads into RANK / WORLD_SIZE at import time.
            env["CUDA_VISIBLE_DEVICES"] = str(visible_devices[i % len(visible_devices)])
            env["SLURM_PROCID"] = str(i)
            env["SLURM_NTASKS"] = str(num_workers)

            processes.append(sp.Popen([sys.executable, *sys.argv], env=env))

        # Wait for every worker, then report failures instead of silently
        # claiming success when a child crashed.
        exit_codes = [p.wait() for p in processes]
        failed = [(rank, code) for rank, code in enumerate(exit_codes) if code != 0]
        if failed:
            for rank, code in failed:
                logger.error(f"Worker {rank} exited with code {code}")
            sys.exit(1)

        logger.info("All workers finished")
        return

    # Worker mode: process this rank's share of the files.
    logger.info("Starting worker")
    if filelist:
        files = [i[0] for i in load_filelist(filelist)]
    else:
        files = list_files(folder, AUDIO_EXTENSIONS, recursive=True, sort=False)

    print(f"Found {len(files)} files")
    # Skip files that already have an output next to them.
    files = [Path(f) for f in files if not Path(f).with_suffix(".npy").exists()]

    total_files = len(files)
    # Strided sharding: rank r takes files r, r + WORLD_SIZE, r + 2 * WORLD_SIZE, ...
    files = files[RANK::WORLD_SIZE]
    logger.info(f"Processing {len(files)}/{total_files} files")

    total_time = 0
    begin_time = time.time()
    processed_files = 0
    model = get_model(config_name, checkpoint_path)

    for n_batch, idx in enumerate(range(0, len(files), batch_size)):
        batch = files[idx : idx + batch_size]
        batch_time = process_batch(batch, model)

        total_time += batch_time
        processed_files += len(batch)

        # Progress report every 10 batches; the ETA extrapolates the average
        # wall-clock time per file so far over the files still to do.
        if (n_batch + 1) % 10 == 0:
            eta = (
                (time.time() - begin_time)
                / processed_files
                * (len(files) - processed_files)
            )
            logger.info(
                f"Processed {processed_files} files, {total_time / 3600:.2f} hours of audio, "
                + f"ETA: {timedelta(seconds=round(eta))}s"
            )

    logger.info(
        f"Finished processing {len(files)} files, {total_time / 3600:.2f} hours of audio"
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Invoke the click command only when this file is executed as a script.
    main()
|
|
|