Spaces:
Sleeping
Sleeping
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import os | |
import torch | |
import numpy as np | |
import yaml | |
import copy | |
from tqdm import tqdm | |
from torchaudio.compliance import kaldi | |
from torch.nn.utils.rnn import pad_sequence | |
from torch.utils.data import DataLoader | |
from fairseq import checkpoint_utils | |
from transformers import AutoModel, Wav2Vec2FeatureExtractor | |
from utils.io_optim import ( | |
TorchaudioDataset, | |
LibrosaDataset, | |
FFmpegDataset, | |
collate_batch, | |
) | |
import whisper | |
from modules.wenet_extractor.utils.init_model import init_model | |
from modules.wenet_extractor.utils.checkpoint import load_checkpoint | |
""" | |
Extractor for content features | |
1. whisper | |
2. contentvec | |
3. wenet | |
4. mert | |
Pipeline: | |
in preprocess.py: | |
call extract_utt_content_features() to extract content features for each utterance | |
extract_utt_content_features() envelopes the following steps: | |
1. load the model (whisper, contentvec, wenet) | |
2. extract the content features | |
3. save the content features into files | |
in svc_dataset.py: | |
call offline_align() to align the content features to the given target length | |
""" | |
""" | |
Extractor Usage: | |
1. initialize an instance of extractor | |
extractor = WhisperExtractor(cfg) | |
2. load the specified model | |
extractor.load_model() | |
3. extract the content features | |
extractor.extract_content(utt) for single utterance | |
extractor.extract_content_batch(utts) for batch utterances | |
4. save the content features | |
extractor.save_feature(utt, content_feature) for single utterance | |
""" | |
class BaseExtractor:
    """Shared logic for the content-feature extractors in this module.

    Subclasses set ``self.extractor_type``, load ``self.model`` in their own
    ``load_model()`` and implement ``extract_content_features()``. This base
    class provides frame-rate alignment (``offline_align``) and on-disk saving
    (``save_feature``).
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.extractor_type = None
        self.model = None

    def offline_align(self, content, target_len):
        """Map a content feature sequence onto ``target_len`` target frames.

        The source hop (in samples at ``cfg.preprocess.sample_rate``) is derived
        from the extractor's frameshift config; the target hop is
        ``cfg.preprocess.hop_size``. Both hops are reduced by their GCD. The
        features are then upsampled by repetition onto a common grid and
        downsampled by averaging over each target hop.

        Args:
            content: array of shape (source_len, dim)
            target_len: number of frames wanted in the output
        Returns:
            mapped_feature: array of shape (target_len, dim). If the aligned
            sequence is shorter than ``target_len``, its last frame is repeated.
        """
        target_hop = self.cfg.preprocess.hop_size

        assert self.extractor_type in ["whisper", "contentvec", "wenet"]
        if self.extractor_type == "whisper":
            source_hop = (
                self.cfg.preprocess.whisper_frameshift
                * self.cfg.preprocess.whisper_downsample_rate
                * self.cfg.preprocess.sample_rate
            )
        elif self.extractor_type == "contentvec":
            source_hop = (
                self.cfg.preprocess.contentvec_frameshift
                * self.cfg.preprocess.sample_rate
            )
        elif self.extractor_type == "wenet":
            source_hop = (
                self.cfg.preprocess.wenet_frameshift
                * self.cfg.preprocess.wenet_downsample_rate
                * self.cfg.preprocess.sample_rate
            )
        source_hop = int(source_hop)

        # Reduce both hops to their smallest integer ratio.
        factor = np.gcd(source_hop, target_hop)
        source_hop //= factor
        target_hop //= factor

        # content: (source_len, dim)
        _, width = content.shape
        # Keep only the source frames needed to cover target_len; the +1 covers
        # the remainder of the integer division.
        source_len = min(target_len * target_hop // source_hop + 1, len(content))
        # Largest multiple of target_hop within source_len * source_hop
        # (const ~= target_len * target_hop).
        const = source_len * source_hop // target_hop * target_hop
        # (source_len * source_hop, dim)
        up_sampling_feats = np.repeat(content, source_hop, axis=0)
        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
        down_sampling_feats = np.average(
            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
        )

        err = abs(target_len - len(down_sampling_feats))
        if err > 8:
            self._record_max_align_err(err)

        if len(down_sampling_feats) < target_len:
            # (1, dim) -> (err, dim): pad by repeating the last frame
            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

        # (target_len, dim)
        mapped_feature = down_sampling_feats[:target_len]
        return mapped_feature

    def _record_max_align_err(self, err):
        """Write ``err`` to ``{processed_dir}/align_max_err.log`` if it exceeds the stored value."""
        err_log_path = os.path.join(
            self.cfg.preprocess.processed_dir, "align_max_err.log"
        )
        try:
            with open(err_log_path, "r") as f:
                err_num = int(f.read())
        except (OSError, ValueError):
            # Missing or unparsable log file: no previous maximum recorded.
            err_num = 0
        if err > err_num:
            with open(err_log_path, "w") as f:
                f.write(str(err))

    def save_feature(self, utt, content_feature):
        """Save one utterance's content feature as ``.npy``.

        The file is written to
        ``{cfg.preprocess.processed_dir}/{utt["Dataset"]}/{extractor_type}/{utt["Uid"]}.npy``.
        Frames beyond the utterance's duration (e.g. from batch padding) are
        dropped before saving.

        Args:
            utt (dict): one item in metadata; "Uid", "Dataset" and "Duration" are read.
                "Duration" must use the same time unit as the frameshift config.
            content_feature (tensor): content feature of one utterance, (num_frames, dim)
        Raises:
            NotImplementedError: if ``extractor_type`` is not a known extractor.
        """
        uid = utt["Uid"]
        assert self.extractor_type is not None
        out_dir = os.path.join(
            self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
        )
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, uid + ".npy")

        # only keep effective parts
        duration = utt["Duration"]
        if self.extractor_type == "whisper":
            frameshift = (
                self.cfg.preprocess.whisper_frameshift
                * self.cfg.preprocess.whisper_downsample_rate
            )  # 20ms
        elif self.extractor_type == "contentvec":
            frameshift = self.cfg.preprocess.contentvec_frameshift  # 20ms
        elif self.extractor_type == "wenet":
            frameshift = (
                self.cfg.preprocess.wenet_frameshift
                * self.cfg.preprocess.wenet_downsample_rate
            )  # 40ms
        elif self.extractor_type == "mert":
            frameshift = self.cfg.preprocess.mert_frameshift
        else:
            raise NotImplementedError

        # calculate the number of valid frames
        num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1
        # (num_frames, dim) -> (valid_frames, dim)
        assert (
            len(content_feature.shape) == 2
        ), "content feature shape error, it should be (num_frames, dim)"
        content_feature = content_feature[:num_frames, :]
        np.save(save_path, content_feature.cpu().detach().numpy())
class WhisperExtractor(BaseExtractor):
    """Content-feature extractor that uses the audio encoder of an OpenAI Whisper model."""

    def __init__(self, config):
        super(WhisperExtractor, self).__init__(config)
        self.extractor_type = "whisper"
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self):
        """Load the Whisper checkpoint named by ``cfg.preprocess.whisper_model``.

        If ``cfg.preprocess.whisper_model_path`` is configured and points to an
        existing file or directory, that location is passed as ``download_root``.
        Otherwise ``download_root`` is ``None`` and whisper uses its default location.
        """
        # load whisper checkpoint
        print("Loading Whisper Model...")

        # Default when the config has no usable whisper_model_path; without this
        # the name would be unbound below.
        download_root = None
        if "whisper_model_path" in self.cfg.preprocess:
            model_path = self.cfg.preprocess.whisper_model_path
            if os.path.isfile(model_path):
                # e.g. "pretrained/whisper/medium.pt"
                download_root = os.path.dirname(model_path)
            elif os.path.isdir(model_path):
                # e.g. "pretrained/whisper"
                download_root = model_path

        model = whisper.load_model(
            self.cfg.preprocess.whisper_model, self.device, download_root
        )
        if torch.cuda.is_available():
            print("Using GPU...\n")
            model = model.cuda()
        else:
            print("Using CPU...\n")

        self.model = model.eval()

    def extract_content_features(self, wavs, lens):
        """Extract Whisper encoder features from a batch of the dataloader.

        Args:
            wavs: tensor (batch_size, T)
            lens: list, valid length of each waveform (unused here; the output is
                trimmed later by ``save_feature`` using the utterance duration)
        Returns:
            tensor (batch, 1500, d_model), the output of ``model.embed_audio``
            (d_model is 1024 for the medium model)
        """
        # Pad or trim every waveform to Whisper's fixed input length.
        wavs = whisper.pad_or_trim(wavs)
        # batch_mel: (batch, 80, 3000)
        batch_mel = whisper.log_mel_spectrogram(wavs, device=self.model.device)
        with torch.no_grad():
            features = self.model.embed_audio(batch_mel)
        return features
class ContentvecExtractor(BaseExtractor):
    """Content-feature extractor backed by a ContentVec checkpoint loaded through fairseq."""

    def __init__(self, cfg):
        super(ContentvecExtractor, self).__init__(cfg)
        self.extractor_type = "contentvec"

    def load_model(self):
        """Load the fairseq checkpoint at ``cfg.preprocess.contentvec_file`` (eval mode)."""
        assert self.model is None
        # Load model
        ckpt_path = self.cfg.preprocess.contentvec_file
        print("Load Contentvec Model...")

        models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
            [ckpt_path],
            suffix="",
        )
        model = models[0]
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()

        self.model = model

    def extract_content_features(self, wavs, lens):
        """Extract ContentVec features from a batch of the dataloader.

        Args:
            wavs: tensor (batch, T), zero-padded waveforms
            lens: list, valid length (in samples) of each waveform
        Returns:
            feats: (batch, frames, dim), the output of ``model.final_proj`` applied
            to layer-12 features
        """
        device = next(self.model.parameters()).device
        wavs = wavs.to(device)  # (batch, max_len)

        # True marks padding: every position at or past each utterance's real
        # length. Comparing samples with 0.0 instead would also mark genuine
        # silent samples inside the audio as padding.
        lengths = torch.as_tensor(lens, device=device)
        positions = torch.arange(wavs.shape[-1], device=device)
        padding_mask = positions[None, :] >= lengths[:, None]

        with torch.no_grad():
            logits = self.model.extract_features(
                source=wavs, padding_mask=padding_mask, output_layer=12
            )
            feats = self.model.final_proj(logits[0])
        return feats
class WenetExtractor(BaseExtractor):
    """Content-feature extractor that runs Kaldi fbank/MFCC features through a WeNet encoder."""

    def __init__(self, config):
        super(WenetExtractor, self).__init__(config)
        self.extractor_type = "wenet"
        # WeNet "dataset_conf" section (feature settings); populated by load_model().
        self.extract_conf = None

    def load_model(self):
        """Build the WeNet model from ``cfg.preprocess.wenet_config`` and load
        weights from ``cfg.preprocess.wenet_model_path``."""
        wenet_cfg = self.cfg.preprocess.wenet_config
        wenet_model_path = self.cfg.preprocess.wenet_model_path
        # load Wenet config (yaml.FullLoader: the config is assumed to be a trusted local file)
        with open(wenet_cfg, "r") as w:
            wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
        self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])
        print("Loading Wenet Model...")
        self.model = init_model(wenet_configs)
        load_checkpoint(self.model, wenet_model_path)

        if torch.cuda.is_available():
            print("Using GPU...\n")
            self.model = self.model.cuda()
        else:
            print("Using CPU...\n")

        self.model = self.model.eval()

    def extract_content_features(self, wavs, lens):
        """Compute Kaldi features per utterance and run them through the WeNet encoder.

        Args:
            wavs: tensor (batch, T), zero-padded waveforms (the Kaldi calls below
                assume a 16 kHz sample rate)
            lens: list, valid length (in samples) of each waveform
        Returns:
            the output of ``self.model.encoder_extractor`` for the padded feature batch
        """
        # Check before touching self.model so a missing load_model() call gives a clear error.
        assert self.extract_conf is not None, "load model first!"
        device = next(self.model.parameters()).device

        # Extract fbank/mfcc features by kaldi
        feats_type = self.extract_conf.get("feats_type", "fbank")
        assert feats_type in ["fbank", "mfcc"]

        feats_list = []
        lengths_list = []
        with torch.no_grad():
            for idx, wav in enumerate(wavs):
                # wav: (T) -> drop the batch padding
                wav = wav[: lens[idx]].to(device)

                # pad one frame (160 samples = 10 ms at 16 kHz) to compensate for
                # the frame cut off after feature extraction
                pad_tensor = torch.zeros(160, device=wav.device)
                wav = torch.cat((wav, pad_tensor), dim=-1)
                # scale to the 16-bit integer range (presumably wavs are floats in [-1, 1])
                wav *= 1 << 15

                wav = wav.unsqueeze(0)  # (T) -> (1, T)
                if feats_type == "fbank":
                    fbank_conf = self.extract_conf.get("fbank_conf", {})
                    feat = kaldi.fbank(
                        wav,
                        sample_frequency=16000,
                        num_mel_bins=fbank_conf["num_mel_bins"],
                        frame_length=fbank_conf["frame_length"],
                        frame_shift=fbank_conf["frame_shift"],
                        dither=fbank_conf["dither"],
                    )
                elif feats_type == "mfcc":
                    # WeNet's dataset_conf stores MFCC options under "mfcc_conf"
                    # (parallel to "fbank_conf"); fall back to the legacy "mfcc" key.
                    mfcc_conf = self.extract_conf.get(
                        "mfcc_conf", self.extract_conf.get("mfcc", {})
                    )
                    feat = kaldi.mfcc(
                        wav,
                        sample_frequency=16000,
                        num_mel_bins=mfcc_conf["num_mel_bins"],
                        frame_length=mfcc_conf["frame_length"],
                        frame_shift=mfcc_conf["frame_shift"],
                        dither=mfcc_conf["dither"],
                        num_ceps=mfcc_conf.get("num_ceps", 40),
                        high_freq=mfcc_conf.get("high_freq", 0.0),
                        low_freq=mfcc_conf.get("low_freq", 20.0),
                    )
                feats_list.append(feat)
                lengths_list.append(feat.shape[0])

            feats_lengths = torch.tensor(lengths_list, dtype=torch.int32).to(device)
            # (batch, max_frames, feat_dim)
            feats_tensor = pad_sequence(feats_list, batch_first=True).to(device)

            features = self.model.encoder_extractor(
                feats_tensor,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False,
            )
        return features
class MertExtractor(BaseExtractor):
    """Content-feature extractor backed by a HuggingFace MERT model."""

    def __init__(self, cfg):
        super(MertExtractor, self).__init__(cfg)
        self.extractor_type = "mert"
        self.preprocessor = None

    def load_model(self):
        """Load the MERT model and its feature extractor named by ``cfg.preprocess.mert_model``."""
        assert self.model is None
        assert self.preprocessor is None

        print("Loading MERT Model: ...", self.cfg.preprocess.mert_model)

        model_name = self.cfg.preprocess.mert_model
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

        if torch.cuda.is_available():
            model = model.cuda()
        preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_name, trust_remote_code=True
        )

        self.model = model
        self.preprocessor = preprocessor

    def extract_content_features(self, wavs, lens):
        """Extract MERT hidden states for every utterance of a dataloader batch.

        Args:
            wavs: tensor (batch, T), zero-padded waveforms
            lens: list, valid length (in samples) of each waveform
        Returns:
            list with one tensor per utterance, each (frame_len, hidden_dim), taken
            from hidden layer ``cfg.preprocess.mert_feature_layer``
        """
        with torch.no_grad():
            sample_rate = self.preprocessor.sampling_rate
            device = next(self.model.parameters()).device
            assert (
                sample_rate == self.cfg.preprocess.mert_sample_rate
            ), "mert sample rate mismatch, expected {}, got {}".format(
                self.cfg.preprocess.mert_sample_rate, sample_rate
            )

            mert_features = []
            for wav, length in zip(wavs, lens):
                # Feed a single utterance, trimmed to its true length, so batch
                # padding does not leak into the features.
                audio = wav[:length].cpu().numpy()
                # {input_values: tensor, attention_mask: tensor}
                inputs = self.preprocessor(
                    audio, sampling_rate=sample_rate, return_tensors="pt"
                ).to(device)

                outputs = self.model(**inputs, output_hidden_states=True)
                # (1, frame_len, hidden_dim) -> (frame_len, hidden_dim)
                feature = outputs.hidden_states[
                    self.cfg.preprocess.mert_feature_layer
                ].squeeze(0)
                mert_features.append(feature)

        return mert_features
def _extract_and_save_content_features(
    cfg,
    metadata,
    dataset_name,
    num_workers,
    feature_name,
    sample_rate_key,
    dataset_cls,
    extractor_cls,
):
    """Extract one content-feature type for all utterances and save each one to disk.

    Args:
        cfg: experiment config
        metadata: list of utterance dicts
        dataset_name: name used for the output directory of the skip check
        num_workers: number of DataLoader worker processes
        feature_name: sub-directory name, e.g. "whisper"
        sample_rate_key: attribute of ``cfg.preprocess`` giving the loading sample rate
        dataset_cls: waveform dataset class, called as (cfg, dataset_name, sample_rate, metadata=...)
        extractor_cls: extractor class providing load_model / extract_content_features / save_feature
    """
    feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, feature_name)
    os.makedirs(feat_dir, exist_ok=True)
    # Skip heuristic: one existing file per utterance means extraction already ran.
    if len(os.listdir(feat_dir)) == len(metadata):
        return

    # Read the sample rate only when extraction actually runs.
    sample_rate = getattr(cfg.preprocess, sample_rate_key)
    waveforms = dataset_cls(
        cfg,
        dataset_name,
        sample_rate,
        metadata=metadata,
    )
    data_loader = DataLoader(
        waveforms,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=cfg.preprocess.pin_memory,
        batch_size=cfg.preprocess.content_feature_batch_size,
        collate_fn=collate_batch,
        drop_last=False,
    )

    extractor = extractor_cls(cfg)
    extractor.load_model()
    for batch_metadata, wavs, lens in tqdm(data_loader):
        batch_content_features = extractor.extract_content_features(wavs, lens)
        for index, utt in enumerate(batch_metadata):
            extractor.save_feature(utt, batch_content_features[index])


def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
    """Extract and save every content-feature type enabled in ``cfg.preprocess``.

    The ``cfg.preprocess.extract_{whisper,contentvec,wenet,mert}_feature`` flags
    select which extractors run. Each enabled type is skipped when its output
    directory already holds one file per utterance.

    Args:
        cfg: experiment config
        metadata: list of utterance dicts; the dataset name for the output
            directory check is taken from the first item's "Dataset"
        num_workers: number of DataLoader worker processes
    """
    if not metadata:
        # Nothing to extract, and no dataset name to derive output paths from.
        return

    dataset_name = metadata[0]["Dataset"]
    with torch.no_grad():
        if cfg.preprocess.extract_whisper_feature:
            _extract_and_save_content_features(
                cfg,
                metadata,
                dataset_name,
                num_workers,
                "whisper",
                "whisper_sample_rate",
                FFmpegDataset,
                WhisperExtractor,
            )

        if cfg.preprocess.extract_contentvec_feature:
            _extract_and_save_content_features(
                cfg,
                metadata,
                dataset_name,
                num_workers,
                "contentvec",
                "contentvec_sample_rate",
                LibrosaDataset,
                ContentvecExtractor,
            )

        if cfg.preprocess.extract_wenet_feature:
            _extract_and_save_content_features(
                cfg,
                metadata,
                dataset_name,
                num_workers,
                "wenet",
                "wenet_sample_rate",
                TorchaudioDataset,
                WenetExtractor,
            )

        if cfg.preprocess.extract_mert_feature:
            _extract_and_save_content_features(
                cfg,
                metadata,
                dataset_name,
                num_workers,
                "mert",
                "mert_sample_rate",
                TorchaudioDataset,
                MertExtractor,
            )