Spaces:
Running
Running
import argparse | |
from concurrent.futures import ThreadPoolExecutor | |
import warnings | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
import utils | |
from common.log import logger | |
from common.stdout_wrapper import SAFE_STDOUT | |
from config import config | |
warnings.filterwarnings("ignore", category=UserWarning) | |
from pyannote.audio import Inference, Model | |
# Device used for style-vector extraction, taken from the project config.
device = torch.device(config.style_gen_config.device)

# Pretrained pyannote embedding model, wrapped in an `Inference` object
# configured with window="whole", then moved to the configured device.
model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
inference = Inference(model, window="whole")
inference.to(device)
class NaNValueError(ValueError):
    """Custom exception raised when a NaN value is found."""
# Short, but kept as a function so that it can be imported at inference time.
def get_style_vector(wav_path):
    """Return the result of running the module-level ``inference`` on ``wav_path``."""
    style_vec = inference(wav_path)
    return style_vec
def save_style_vector(wav_path):
    """Compute the style vector for ``wav_path`` and save it next to the audio file.

    The vector is written with ``np.save`` to ``f"{wav_path}.npy"``
    (e.g. ``test.wav`` -> ``test.wav.npy``).

    Raises:
        NaNValueError: if the computed vector contains any NaN; nothing is saved.
        Exception: any error raised by ``get_style_vector`` is logged and re-raised.
    """
    try:
        style_vec = get_style_vector(wav_path)
    except Exception as e:
        # Log which file failed, then let the original exception propagate.
        print("\n")
        logger.error(f"Error occurred with file: {wav_path}, Details:\n{e}\n")
        raise

    # NaN values in the vector are harmful downstream, so reject them here.
    has_nan = np.isnan(style_vec).any()
    if has_nan:
        nan_message = f"NaN value found in style vector: {wav_path}"
        print("\n")
        logger.warning(nan_message)
        raise NaNValueError(nan_message)

    output_path = f"{wav_path}.npy"  # `test.wav` -> `test.wav.npy`
    np.save(output_path, style_vec)
def process_line(line):
    """Generate and save the style vector for one ``|``-separated list line.

    The text before the first ``|`` is used as the wav path.

    Returns:
        ``(line, None)`` on success, or ``(line, "nan_error")`` when
        ``save_style_vector`` raised ``NaNValueError``. Any other exception
        propagates to the caller.
    """
    wav_path = line.partition("|")[0]
    try:
        save_style_vector(wav_path)
    except NaNValueError:
        return line, "nan_error"
    else:
        return line, None
def save_average_style_vector(style_vectors, filename="style_vectors.npy"):
    """Save the mean of ``style_vectors`` along axis 0 to ``filename`` via ``np.save``."""
    np.save(filename, np.mean(style_vectors, axis=0))
def _collect_valid_lines(list_path, num_workers, split_name):
    """Generate style vectors for every line of one file list.

    Reads ``list_path``, runs ``process_line`` on each line in a thread pool
    (with a progress bar written to ``SAFE_STDOUT``), and logs a warning
    naming the files whose style vector contained NaN.

    Args:
        list_path: path of the ``|``-separated file list to read.
        num_workers: ``max_workers`` for the thread pool.
        split_name: word inserted into the warning message
            (``"training"`` or ``"validation"``).

    Returns:
        The lines that were processed without error, in their original order.
        The list file itself is NOT modified here.
    """
    with open(list_path, encoding="utf-8") as f:
        lines = f.readlines()

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        results = list(
            tqdm(
                executor.map(process_line, lines),
                total=len(lines),
                file=SAFE_STDOUT,
            )
        )

    ok_lines = [line for line, error in results if error is None]
    nan_lines = [line for line, error in results if error == "nan_error"]

    if nan_lines:
        nan_files = [line.split("|")[0] for line in nan_lines]
        logger.warning(
            f"Found NaN value in {len(nan_lines)} files: {nan_files}, so they will be deleted from {split_name} data."
        )
    return ok_lines


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c", "--config", type=str, default=config.style_gen_config.config_path
    )
    parser.add_argument(
        "--num_processes", type=int, default=config.style_gen_config.num_processes
    )
    # parse_known_args: unrecognized command-line arguments are ignored.
    args, _ = parser.parse_known_args()
    config_path = args.config
    num_processes = args.num_processes

    hps = utils.get_hparams_from_file(config_path)

    # Process both lists completely before touching either file, so that an
    # exception during processing leaves the original lists untouched.
    ok_training_lines = _collect_valid_lines(
        hps.data.training_files, num_processes, "training"
    )
    ok_val_lines = _collect_valid_lines(
        hps.data.validation_files, num_processes, "validation"
    )

    # Rewrite the lists so that lines whose style vector contained NaN are dropped.
    with open(hps.data.training_files, "w", encoding="utf-8") as f:
        f.writelines(ok_training_lines)
    with open(hps.data.validation_files, "w", encoding="utf-8") as f:
        f.writelines(ok_val_lines)

    ok_num = len(ok_training_lines) + len(ok_val_lines)
    logger.info(f"Finished generating style vectors! total: {ok_num} npy files.")