# NeuCoSVC-Colab/dataset/prematch_dataset.py
import argparse
import os
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch import Tensor
from tqdm import tqdm
import resampy
from modules.wavlm_encoder import WavLMEncoder
from utils.tools import fast_cosine_dist
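# WavLM's convolutional front-end has a total stride of 320 samples, i.e. one
# feature frame per 20 ms at 16 kHz; waveforms are padded up to a multiple of
# this hop before feature extraction.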
DOWNSAMPLE_FACTOR = 320
def make_opensinger_df(root_path: Path) -> pd.DataFrame:
    all_files = []
    folders = ['ManRaw', 'WomanRaw']
    for f in folders:
        all_files.extend(list((root_path / f).rglob('*.wav')))
    # f.parts[-3][:-3] strips the 'Raw' suffix ('ManRaw' -> 'Man');
    # f.stem.split('_')[0] is the singer ID within that gender folder.
    speakers = [f.parts[-3][:-3] + '-' + f.stem.split('_')[0] for f in all_files]
    df = pd.DataFrame({'path': all_files, 'speaker': speakers})
    return df
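# A minimal sketch of the OpenSinger layout this parser assumes, inferred from
# the path handling above (the example names are illustrative, not taken from
# the dataset documentation):
#   <data_root>/ManRaw/<singerID>_<songName>/<singerID>_<songName>_<segment>.wav
#   e.g. ManRaw/27_Sunrise/27_Sunrise_3.wav -> speaker 'Man-27', song 'Sunrise'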
def main(args):
    data_root = Path(args.data_root)
    out_dir = Path(args.out_dir) if args.out_dir is not None else data_root / 'wavlm_features'
    device = torch.device(args.device)
    seed = args.seed
    # One-hot weightings over WavLM-Large's 25 hidden states (input embedding +
    # 24 transformer layers): synthesis uses a single layer, matching averages
    # several late layers.
    SYNTH_WEIGHTINGS = F.one_hot(torch.tensor(args.synthesis_layer), num_classes=25).float().to(device)[:, None]
    MATCH_WEIGHTINGS = F.one_hot(torch.tensor(args.matching_layer), num_classes=25).float().mean(dim=0).to(device)[:, None]
    print(f"Matching weight: {MATCH_WEIGHTINGS.squeeze()}\nSynthesis weight: {SYNTH_WEIGHTINGS.squeeze()}")
    ls_df = make_opensinger_df(data_root)
    print("Loading WavLM.")
    wavlm = WavLMEncoder('pretrained/WavLM-Large.pt', device=device)
    np.random.seed(seed)
    torch.manual_seed(seed)
    extract(ls_df, wavlm, device, data_root, out_dir, SYNTH_WEIGHTINGS, MATCH_WEIGHTINGS, topk=args.topk)
    print("All done!", flush=True)
@torch.inference_mode()
def get_full_features(path, wavlm, device):
    x, sr = torchaudio.load(path)
    if sr != 16000:
        # WavLM expects 16 kHz input; resample along the time axis.
        x = resampy.resample(x.numpy(), sr, 16000, axis=1)
        x = torch.from_numpy(x).to(dtype=torch.float)
    # Pad the waveform up to a whole number of WavLM frames (no padding if it
    # is already aligned).
    n_pad = (DOWNSAMPLE_FACTOR - x.shape[-1] % DOWNSAMPLE_FACTOR) % DOWNSAMPLE_FACTOR
    x = F.pad(x, (0, n_pad), value=0)
    # Extract the hidden representation of every WavLM layer.
    wav_input_16khz = x.to(device)
    features = wavlm.get_features(wav_input_16khz)
    return features
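# Note: get_full_features is assumed (from the weighted sums in extract below)
# to return all 25 WavLM-Large hidden states stacked as roughly
# (25, seq_len, 1024); the layer weightings then collapse the first axis.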
@torch.inference_mode()
def extract(df: pd.DataFrame, wavlm: nn.Module, device, data_root: Path, out_dir: Path, synth_weights: Tensor, match_weights: Tensor, topk: int = 4):
    mb = tqdm(df.groupby('speaker'), desc='Total Progress')
    for speaker, paths in mb:
        if len(paths) == 1:
            print(f"Speaker {speaker} has only one audio file; skipping.")
            continue
        targ_paths = {}
        for i, row in paths.iterrows():
            rel_path = row.path.relative_to(data_root)
            targ_paths[row.path] = (out_dir / rel_path).with_suffix('.pt')
        if all(p.exists() for p in targ_paths.values()):
            continue
        feature_cache = {}
        synthesis_cache = {}
        # 1. Extract the WavLM features of all of this speaker's audio.
        pb = tqdm(paths.iterrows(), total=len(paths), desc=f'extracting {speaker}')
        for i, row in pb:
            feats = get_full_features(row.path, wavlm, device)
            matching_feats = (feats * match_weights[:, None]).sum(dim=0)  # (seq_len, dim)
            synth_feats = (feats * synth_weights[:, None]).sum(dim=0)    # (seq_len, dim)
            feature_cache[row.path] = matching_feats
            synthesis_cache[row.path] = synth_feats
        # 2. Replace the WavLM features of each recording with features drawn
        #    from the same singer's *other* songs.
        pb = tqdm(paths.iterrows(), total=len(paths), desc=f'prematching {speaker}')
        for i, row in pb:
            targ_path = targ_paths[row.path]
            if targ_path.is_file():
                continue
            os.makedirs(targ_path.parent, exist_ok=True)
            source_feats = feature_cache[row.path]
            # Exclude segments of the same song, since a song contains repeated
            # phrases. Note this is a substring match on the filename stem.
            song_name = row.path.stem.split('_')[1]
            filtered_matching_feats = {key: value for key, value in feature_cache.items() if song_name not in key.stem}
            matching_pool = torch.cat(list(filtered_matching_feats.values()), dim=0)
            filtered_synth_feats = {key: value for key, value in synthesis_cache.items() if song_name not in key.stem}
            synth_pool = torch.cat(list(filtered_synth_feats.values()), dim=0)
            # Calculate distances and replace each source frame with the mean of
            # its k nearest neighbours from the pool.
            matching_pool = matching_pool.to(device)
            synth_pool = synth_pool.to(device)
            dists = fast_cosine_dist(source_feats, matching_pool, device=device)
            best = dists.topk(k=topk, dim=-1, largest=False)  # (src_len, topk)
            out_feats = synth_pool[best.indices].mean(dim=1)  # (src_len, dim)
            # 3. Save the pre-matched sequence.
            torch.save(out_feats.cpu(), str(targ_path))
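# Shape sketch of the prematching step above, assuming WavLM-Large's 1024-dim
# features (fast_cosine_dist is assumed to return pairwise cosine distances):
#   source_feats  : (src_len, 1024)     query frames of the current recording
#   matching_pool : (pool_len, 1024)    frames from the singer's other songs
#   dists         : (src_len, pool_len) pairwise cosine distances
#   best.indices  : (src_len, topk)     indices of the k closest pool frames
#   out_feats     : (src_len, 1024)     mean of the topk synthesis-layer frames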
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Compute matched WavLM features for an OpenSinger dataset")
    parser.add_argument('--data_root', required=True, type=str)
    parser.add_argument('--seed', default=123, type=int)
    parser.add_argument('--out_dir', type=str)
    parser.add_argument('--device', default='cuda', type=str)
    parser.add_argument('--topk', type=int, default=4)
    parser.add_argument('--matching_layer', type=int, default=[20, 21, 22, 23, 24], nargs='+')
    parser.add_argument('--synthesis_layer', type=int, default=6)
    args = parser.parse_args()
    main(args)
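# Example invocation (hypothetical paths), mirroring the defaults above:
#   python dataset/prematch_dataset.py --data_root /path/to/OpenSinger \
#          --device cuda --topk 4 --matching_layer 20 21 22 23 24 --synthesis_layer 6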