# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import numpy as np
import yaml
import copy
from tqdm import tqdm
from torchaudio.compliance import kaldi
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from fairseq import checkpoint_utils
from transformers import AutoModel, Wav2Vec2FeatureExtractor

from utils.io_optim import (
    TorchaudioDataset,
    LibrosaDataset,
    FFmpegDataset,
    collate_batch,
)
from modules import whisper_extractor as whisper
from modules.wenet_extractor.utils.init_model import init_model
from modules.wenet_extractor.utils.checkpoint import load_checkpoint

"""
    Extractor for content features
    1. whisper
    2. contentvec
    3. wenet
    4. mert

    Pipeline:
        in preprocess.py:
            call extract_utt_content_features_dataloader() to extract content features for each utterance
            extract_utt_content_features_dataloader() envelopes the following steps:
                1. load the model (whisper, contentvec, wenet, mert)
                2. extract the content features
                3. save the content features into files
        in svc_dataset.py:
            call offline_align() to align the content features to the given target length

"""

"""
    Extractor Usage:
        1. initialize an instance of extractor
            extractor = WhisperExtractor(cfg)
        2. load the specified model
            extractor.load_model()
        3. extract the content features for a batch of waveforms
            extractor.extract_content_features(wavs, lens)
        4. save the content features
            extractor.save_feature(utt, content_feature) for single utterance
"""


class BaseExtractor:
    """Shared alignment and saving logic for all content-feature extractors.

    Subclasses set ``self.extractor_type`` and provide ``load_model()`` and
    ``extract_content_features(wavs, lens)``.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.extractor_type = None
        self.model = None

    def _frame_shift(self):
        """Return the effective frame shift of the current extractor type.

        The value is read from ``cfg.preprocess`` and, for whisper and wenet,
        multiplied by the corresponding downsample rate.

        Raises:
            NotImplementedError: if ``self.extractor_type`` is not one of
                whisper / contentvec / wenet / mert.
        """
        preprocess_cfg = self.cfg.preprocess
        if self.extractor_type == "whisper":
            return (
                preprocess_cfg.whisper_frameshift
                * preprocess_cfg.whisper_downsample_rate
            )  # 20ms
        if self.extractor_type == "contentvec":
            return preprocess_cfg.contentvec_frameshift  # 20ms
        if self.extractor_type == "wenet":
            return (
                preprocess_cfg.wenet_frameshift * preprocess_cfg.wenet_downsample_rate
            )  # 40ms
        if self.extractor_type == "mert":
            return preprocess_cfg.mert_frameshift
        raise NotImplementedError

    def offline_align(self, content, target_len):
        """Resample content features along time to ``target_len`` frames.

        Each source frame is repeated ``source_hop`` times, then averaged in
        groups of ``target_hop`` (both hops reduced by their gcd). The result
        is padded by repeating its last frame, or truncated, to ``target_len``.

        args:
            content: (source_len, dim)
            target_len: target length
        return:
            mapped_feature: (target_len, dim)
        """
        target_hop = self.cfg.preprocess.hop_size

        # Alignment is only defined for these three extractor types.
        assert self.extractor_type in ["whisper", "contentvec", "wenet"]
        source_hop = int(self._frame_shift() * self.cfg.preprocess.sample_rate)

        # Reduce both hops by their gcd to keep the repeated array small.
        factor = np.gcd(source_hop, target_hop)
        source_hop //= factor
        target_hop //= factor

        # (source_len, 256)
        _, width = content.shape
        # slice the content from padded feature
        source_len = min(target_len * target_hop // source_hop + 1, len(content))

        # const ~= target_len * target_hop
        const = source_len * source_hop // target_hop * target_hop

        # (source_len * source_hop, dim)
        up_sampling_feats = np.repeat(content, source_hop, axis=0)
        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
        down_sampling_feats = np.average(
            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
        )

        err = abs(target_len - len(down_sampling_feats))
        if err > 8:
            # Track the largest length mismatch seen so far in a small log file.
            # err_log_dir is indeterminate
            err_log_dir = os.path.join(
                self.cfg.preprocess.processed_dir, "align_max_err.log"
            )
            try:
                with open(err_log_dir, "r") as f:
                    err_num = int(f.read())
            except (OSError, ValueError):
                # Missing/unreadable file or non-integer content: reset the log.
                with open(err_log_dir, "w") as f:
                    f.write("0")
                err_num = 0
            if err > err_num:
                with open(err_log_dir, "w") as f:
                    f.write(str(err))

        if len(down_sampling_feats) < target_len:
            # Pad by repeating the last frame: (1, dim) -> (err, dim)
            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

        # (target_len, dim)
        mapped_feature = down_sampling_feats[:target_len]

        return mapped_feature

    def save_feature(self, utt, content_feature):
        """Save a single utterance to path {cfg.preprocess.processed_dir}

        The feature is written to
        ``{processed_dir}/{utt["Dataset"]}/{extractor_type}/{utt["Uid"]}.npy``
        after being cut to the number of frames implied by ``utt["Duration"]``.

        Args:
            utt (dict): one item in metadata, containing information for one utterance
            content_feature (tensor): content feature of one utterance, (num_frames, dim)

        Raises:
            NotImplementedError: if the extractor type has no known frame shift.
        """
        uid = utt["Uid"]
        assert self.extractor_type is not None
        out_dir = os.path.join(
            self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
        )
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, uid + ".npy")

        # only keep effective parts
        # NOTE(review): duration and frameshift must share a unit (presumably
        # seconds) -- confirm against the metadata producer.
        duration = utt["Duration"]
        frameshift = self._frame_shift()

        # calculate the number of valid frames
        num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1

        # (num_frames, dim) -> (valid_frames, dim)
        assert (
            len(content_feature.shape) == 2
        ), "content feature shape error, it should be (num_frames, dim)"
        content_feature = content_feature[:num_frames, :]
        np.save(save_path, content_feature.cpu().detach().numpy())


class WhisperExtractor(BaseExtractor):
    """Content extractor backed by the Whisper audio encoder."""

    def __init__(self, config):
        super(WhisperExtractor, self).__init__(config)
        self.extractor_type = "whisper"

    def load_model(self):
        """Load the Whisper model named in the config and put it in eval mode.

        ``cfg.preprocess.whisper_model_path`` is passed as the checkpoint file
        when that key is present; otherwise ``None`` is passed.
        """
        # load whisper checkpoint
        print("Loading Whisper Model...")

        checkpoint_file = (
            self.cfg.preprocess.whisper_model_path
            if "whisper_model_path" in self.cfg.preprocess
            else None
        )
        model = whisper.load_model(
            self.cfg.preprocess.whisper_model, checkpoint_file=checkpoint_file
        )
        if torch.cuda.is_available():
            print("Using GPU...\n")
            model = model.cuda()
        else:
            print("Using CPU...\n")

        self.model = model.eval()

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch_size, T)
            lens: list (unused here)
        Returns:
            The output of ``self.model.embed_audio`` for the batch.
        """
        # wavs: (batch, max_len)
        wavs = whisper.pad_or_trim(wavs)

        # batch_mel: (batch, 80, 3000)
        batch_mel = whisper.log_mel_spectrogram(wavs).to(self.model.device)
        with torch.no_grad():
            # (batch, 1500, 1024)
            features = self.model.embed_audio(batch_mel)
        return features


class ContentvecExtractor(BaseExtractor):
    """Content extractor backed by a ContentVec checkpoint loaded via fairseq."""

    def __init__(self, cfg):
        super(ContentvecExtractor, self).__init__(cfg)
        self.extractor_type = "contentvec"

    def load_model(self):
        """Load the checkpoint at ``cfg.preprocess.contentvec_file`` in eval mode."""
        assert self.model is None
        # Load model
        ckpt_path = self.cfg.preprocess.contentvec_file
        print("Load Contentvec Model...")

        models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
            [ckpt_path],
            suffix="",
        )
        model = models[0]
        model.eval()

        if torch.cuda.is_available():
            # print("Using GPU...\n")
            model = model.cuda()

        self.model = model

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch, T)
            lens: list (unused here)
        Returns:
            The layer-12 features passed through ``self.model.final_proj``.
        """
        device = next(self.model.parameters()).device
        wavs = wavs.to(device)  # (batch, max_len)
        # Every exactly-zero sample is marked as padding.
        padding_mask = torch.eq(wavs, torch.zeros_like(wavs)).to(device)
        with torch.no_grad():
            logits = self.model.extract_features(
                source=wavs, padding_mask=padding_mask, output_layer=12
            )
            # feats: (batch, T, 256)
            feats = self.model.final_proj(logits[0])
        return feats


class WenetExtractor(BaseExtractor):
    """Content extractor backed by a WeNet encoder."""

    def __init__(self, config):
        super(WenetExtractor, self).__init__(config)
        self.extractor_type = "wenet"
        # Filled by load_model(); initialised here so the "load model first!"
        # check in extract_content_features() can actually fire.
        self.extract_conf = None

    def load_model(self):
        """Build the WeNet model from its YAML config and load its checkpoint.

        Reads ``cfg.preprocess.wenet_config`` and
        ``cfg.preprocess.wenet_model_path``; keeps a deep copy of the config's
        ``dataset_conf`` section in ``self.extract_conf``.
        """
        wenet_cfg = self.cfg.preprocess.wenet_config
        wenet_model_path = self.cfg.preprocess.wenet_model_path
        # load Wenet config
        # NOTE(review): yaml.FullLoader is not meant for untrusted input; the
        # config is assumed to be a trusted local file -- confirm, or move to
        # yaml.safe_load if the config needs no custom tags.
        with open(wenet_cfg, "r") as w:
            wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
        self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])
        print("Loading Wenet Model...")
        self.model = init_model(wenet_configs)
        load_checkpoint(self.model, wenet_model_path)

        if torch.cuda.is_available():
            print("Using GPU...\n")
            self.model = self.model.cuda()
        else:
            print("Using CPU...\n")

        self.model = self.model.eval()

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor, one waveform per row
            lens: list, number of valid samples of each waveform
        Returns:
            The output of ``self.model.encoder_extractor`` for the batch.
        """
        feats_list = []
        lengths_list = []

        device = next(self.model.parameters()).device
        # Extract fbank/mfcc features by kaldi
        assert self.extract_conf is not None, "load model first!"

        feats_type = self.extract_conf.get("feats_type", "fbank")
        assert feats_type in ["fbank", "mfcc"]

        # The feature config does not change per utterance, so read it once.
        if feats_type == "fbank":
            feat_conf = self.extract_conf.get("fbank_conf", {})
        else:
            # Prefer "mfcc_conf" (parallel to "fbank_conf"); the legacy "mfcc"
            # key is still honored as a fallback.
            feat_conf = self.extract_conf.get(
                "mfcc_conf", self.extract_conf.get("mfcc", {})
            )

        for idx, wav in enumerate(wavs):
            # wav: (T), cut to its valid length
            wav = wav[: lens[idx]].to(device)

            # pad one frame to compensate for the frame cut off after feature extraction
            pad_tensor = torch.zeros(160, device=wav.device)
            wav = torch.cat((wav, pad_tensor), dim=-1)
            # Scale by 2**15 before the kaldi front end.
            wav *= 1 << 15

            wav = wav.unsqueeze(0)  # (T) -> (1, T)
            if feats_type == "fbank":
                feat = kaldi.fbank(
                    wav,
                    sample_frequency=16000,
                    num_mel_bins=feat_conf["num_mel_bins"],
                    frame_length=feat_conf["frame_length"],
                    frame_shift=feat_conf["frame_shift"],
                    dither=feat_conf["dither"],
                )
            elif feats_type == "mfcc":
                feat = kaldi.mfcc(
                    wav,
                    sample_frequency=16000,
                    num_mel_bins=feat_conf["num_mel_bins"],
                    frame_length=feat_conf["frame_length"],
                    frame_shift=feat_conf["frame_shift"],
                    dither=feat_conf["dither"],
                    num_ceps=feat_conf.get("num_ceps", 40),
                    high_freq=feat_conf.get("high_freq", 0.0),
                    low_freq=feat_conf.get("low_freq", 20.0),
                )
            feats_list.append(feat)
            lengths_list.append(feat.shape[0])

        feats_lengths = torch.tensor(lengths_list, dtype=torch.int32).to(device)
        feats_tensor = pad_sequence(feats_list, batch_first=True).to(
            device
        )  # (batch, len, 80)

        # Inference only: no autograd graph is needed (the other extractors in
        # this file do the same).
        with torch.no_grad():
            features = self.model.encoder_extractor(
                feats_tensor,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False,
            )
        return features


class MertExtractor(BaseExtractor):
    """Content extractor backed by a MERT model loaded through transformers."""

    # Local snapshot that is preferred when it exists on this machine.
    LOCAL_MERT_PATH = "/mnt/workspace/fangzihao/acce/Amphion/pretrained/MERT"

    def __init__(self, cfg):
        super(MertExtractor, self).__init__(cfg)
        self.extractor_type = "mert"
        self.preprocessor = None

    def load_model(self):
        """Load the MERT model and its feature extractor.

        The local snapshot at ``LOCAL_MERT_PATH`` is used when that directory
        exists; otherwise ``cfg.preprocess.mert_model`` is passed to
        ``from_pretrained`` so the extractor also works on other machines.
        """
        assert self.model is None
        assert self.preprocessor is None

        print("Loading MERT Model: ...", self.cfg.preprocess.mert_model)

        model_name = self.cfg.preprocess.mert_model
        if os.path.isdir(self.LOCAL_MERT_PATH):
            model_source = self.LOCAL_MERT_PATH
        else:
            model_source = model_name

        model = AutoModel.from_pretrained(model_source, trust_remote_code=True)

        if torch.cuda.is_available():
            model = model.cuda()
        preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_source, trust_remote_code=True
        )

        self.model = model
        self.preprocessor = preprocessor

    def extract_content_features(self, wavs, lens):
        """extract content features from a batch of dataloader
        Args:
            wavs: tensor (batch, T)
            lens: list (unused here)
        Returns:
            list with one tensor per utterance, taken from hidden layer
            ``cfg.preprocess.mert_feature_layer``.
        """
        with torch.no_grad():
            sample_rate = self.preprocessor.sampling_rate
            device = next(self.model.parameters()).device
            assert (
                sample_rate == self.cfg.preprocess.mert_sample_rate
            ), "mert sample rate mismatch, expected {}, got {}".format(
                self.cfg.preprocess.mert_sample_rate, sample_rate
            )
            mert_features = []
            # wav: (len)
            for wav in wavs:
                # Preprocess this utterance only, not the whole batch.
                # {input_values: tensor, attention_mask: tensor}
                inputs = self.preprocessor(
                    wav, sampling_rate=sample_rate, return_tensors="pt"
                ).to(device)

                outputs = self.model(**inputs, output_hidden_states=True)
                # (1, frame_len, 1024) -> (frame_len, 1024)
                feature = outputs.hidden_states[
                    self.cfg.preprocess.mert_feature_layer
                ].squeeze(0)
                mert_features.append(feature)

        return mert_features


def _extract_and_save_features(
    cfg, metadata, num_workers, dataset_name, feature_name, dataset_cls, extractor_cls
):
    """Run one extractor over all utterances and save each feature to disk.

    Extraction is skipped when ``{processed_dir}/{dataset_name}/{feature_name}``
    already holds as many entries as there are utterances in ``metadata``.
    """
    feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, feature_name)
    os.makedirs(feat_dir, exist_ok=True)
    if len(os.listdir(feat_dir)) == len(metadata):
        return

    # Read lazily so the sample rate is only required when extraction runs.
    sample_rate = getattr(cfg.preprocess, "{}_sample_rate".format(feature_name))
    waveforms = dataset_cls(cfg, dataset_name, sample_rate, metadata=metadata)
    data_loader = DataLoader(
        waveforms,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=cfg.preprocess.pin_memory,
        batch_size=cfg.preprocess.content_feature_batch_size,
        collate_fn=collate_batch,
        drop_last=False,
    )
    extractor = extractor_cls(cfg)
    extractor.load_model()
    for _metadata, wavs, lens in tqdm(data_loader):
        batch_content_features = extractor.extract_content_features(wavs, lens)
        for index, utt in enumerate(_metadata):
            extractor.save_feature(utt, batch_content_features[index])


def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
    """Extract and save every content feature enabled in ``cfg.preprocess``.

    For each of whisper / contentvec / wenet / mert, the feature is extracted
    when ``cfg.preprocess.extract_<name>_feature`` is truthy.

    Args:
        cfg: config object exposing a ``preprocess`` section.
        metadata: list of utterance dicts; the dataset name is taken from the
            first item's "Dataset" entry.
        num_workers: number of DataLoader worker processes.
    """
    if len(metadata) == 0:
        # Nothing to extract, and no first item to read the dataset name from.
        return

    dataset_name = metadata[0]["Dataset"]

    # (feature name, waveform dataset class, extractor class), in run order.
    pipelines = (
        ("whisper", FFmpegDataset, WhisperExtractor),
        ("contentvec", LibrosaDataset, ContentvecExtractor),
        ("wenet", TorchaudioDataset, WenetExtractor),
        ("mert", TorchaudioDataset, MertExtractor),
    )
    for feature_name, dataset_cls, extractor_cls in pipelines:
        if getattr(cfg.preprocess, "extract_{}_feature".format(feature_name)):
            _extract_and_save_features(
                cfg,
                metadata,
                num_workers,
                dataset_name,
                feature_name,
                dataset_cls,
                extractor_cls,
            )