# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import pickle

import numpy as np
import torch
from tqdm import tqdm

from modules import whisper_extractor as whisper


def whisper_encoder_batch(model, audio_paths):
    """Run the Whisper encoder on a batch of audio files.

    Args:
        model: Whisper model exposing ``device`` and ``embed_audio``.
        audio_paths: Sequence of audio file paths (each is converted
            with ``str`` before loading).

    Returns:
        A numpy array with the encoder output for every path, stacked
        along the first axis.
    """
    batch = len(audio_paths)
    batch_mel = torch.zeros(
        (batch, 80, 3000), dtype=torch.float32, device=model.device
    )

    for i, audio_path in enumerate(audio_paths):
        # Every waveform is padded/trimmed to one fixed length so that all
        # mels share the (80, 3000) shape of ``batch_mel``.
        # NOTE(review): the previous comment gave the waveform shape as
        # (48000,); with 3000 mel frames it is presumably (480000,), i.e.
        # 30 s at 16 kHz -- confirm against ``whisper.pad_or_trim``.
        audio = whisper.load_audio(str(audio_path))
        audio = whisper.pad_or_trim(audio)

        # (80, 3000)
        mel = whisper.log_mel_spectrogram(audio).to(model.device)
        batch_mel[i] = mel

    with torch.no_grad():
        # (batch, 1500, 1024)
        features = model.embed_audio(batch_mel)

    return features.cpu().detach().numpy()


def whisper_encoder(model, audio_path):
    """Run the Whisper encoder on a single audio file.

    Args:
        model: Whisper model exposing ``device`` and ``embed_audio``.
        audio_path: Path of the audio file (converted with ``str``).

    Returns:
        A numpy array with the encoder output, with the batch axis removed.
    """
    audio = whisper.load_audio(str(audio_path))
    audio = whisper.pad_or_trim(audio)

    # (80, 3000) -> (1, 80, 3000)
    mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0)

    with torch.no_grad():
        # (1, 1500, 1024) -> (1500, 1024)
        features = model.embed_audio(mel).squeeze(0)

    return features.cpu().detach().numpy()


def get_mapped_whisper_features(
    raw_whisper_features, mapping_features, fast_mapping=True
):
    """Resample Whisper features to the frame rate of the target features.

    Whisper: frameshift = 20ms (30s audio -> 1500 frames), hop_size = 480 in 24k
    # Ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/model.py#L136

    Now it's only used for mapping to bigvgan's mels
    (sr = 24k, hop_size = 256, frameshift ~= 10.7 ms)

    Each source frame is repeated ``source_hop`` times and the result is
    averaged in groups of ``target_hop`` (both hops reduced by their gcd),
    which converts between the two frame rates.

    Args:
        raw_whisper_features: Indexable collection of per-utterance Whisper
            features, each of shape (1500, dim).
        mapping_features: Iterable of per-utterance target features, each
            of shape (target_frame_len, n_mels); only the first dimension
            is used. Must be aligned index-wise with
            ``raw_whisper_features``.
        fast_mapping: If True, only the leading source frames needed to
            cover the target length are resampled instead of all 1500.

    Returns:
        A list with one (target_len, dim) array per utterance, where
        ``target_len`` is the target frame count capped at the maximum
        length 1500 source frames can cover (2812).

    Raises:
        AssertionError: If the resampled sequence is shorter than the
            requested target length.
    """
    source_hop = 480
    target_hop = 256

    factor = np.gcd(source_hop, target_hop)
    source_hop //= factor
    target_hop //= factor
    # After the reduction, ``target_hop`` source frames span exactly as many
    # samples as ``source_hop`` target frames, hence the argument order.
    print(
        "Mapping source's {} frames => target's {} frames".format(
            target_hop, source_hop
        )
    )

    max_source_len = 1500
    # Longest target sequence that 1500 source frames can cover (2812).
    max_target_len = max_source_len * source_hop // target_hop

    whisper_features = []
    for index, mapping_feat in enumerate(tqdm(mapping_features)):
        # mapping_feat: (mels_frame_len, n_mels)
        target_len = min(mapping_feat.shape[0], max_target_len)

        # (1500, dim)
        raw_feats = raw_whisper_features[index]
        width = raw_feats.shape[-1]

        if fast_mapping:
            # Keep just enough source frames (+1 for rounding) to cover
            # target_len target frames.
            source_len = target_len * target_hop // source_hop + 1
            raw_feats = raw_feats[:source_len]
        else:
            source_len = max_source_len

        # const ~= target_len * target_hop; rounded down to a multiple of
        # target_hop so that the reshape below is exact.
        const = source_len * source_hop // target_hop * target_hop

        # (source_len * source_hop, dim)
        up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0)
        # (const, dim) -> (const/target_hop, target_hop, dim)
        #              -> (const/target_hop, dim)
        down_sampling_feats = np.average(
            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
        )
        assert len(down_sampling_feats) >= target_len

        # (target_len, dim)
        feats = down_sampling_feats[:target_len]
        whisper_features.append(feats)

    return whisper_features


def load_whisper_model(hps):
    """Load the Whisper model named by ``hps.whisper_model`` in eval mode.

    The model is moved to CUDA when ``torch.cuda.is_available()`` is True.

    Args:
        hps: Object with a ``whisper_model`` attribute that is passed to
            ``whisper.load_model``.

    Returns:
        The loaded model after ``eval()`` has been called on it.
    """
    print("Loading Whisper Model: ", hps.whisper_model)
    model = whisper.load_model(hps.whisper_model)
    if torch.cuda.is_available():
        model = model.cuda()

    model = model.eval()
    return model


def load_target_acoustic_features(
    output_path, dataset, acoustic_features_name, acoustic_features_fs, dataset_type
):
    """Load the pickled acoustic features that Whisper features are mapped to.

    The file read is
    ``<output_path>/<dataset>/<acoustic_features_name>/<acoustic_features_fs>/<dataset_type>.pkl``.

    Args:
        output_path: Root directory of the processed data.
        dataset: Dataset sub-directory name.
        acoustic_features_name: Feature name; when it equals ``"mels"``
            every feature is transposed from (n_mels, frame_len) to
            (frame_len, n_mels).
        acoustic_features_fs: Second path component under the feature name.
        dataset_type: Base name of the pickle file (without ``.pkl``).

    Returns:
        The loaded features (a list of transposed features for ``"mels"``,
        otherwise whatever object the pickle contains).
    """
    mapping_dir = os.path.join(
        output_path,
        dataset,
        "{}/{}".format(acoustic_features_name, acoustic_features_fs),
    )
    with open(os.path.join(mapping_dir, "{}.pkl".format(dataset_type)), "rb") as f:
        # NOTE: unpickling can execute arbitrary code; only load feature
        # files produced by a trusted preprocessing run.
        mapping_features = pickle.load(f)

    # Mels: (n_mels, frame_len) -> (frame_len, n_mels)
    if acoustic_features_name == "mels":
        print("Transposing mel features...")
        mapping_features = [feat.T for feat in mapping_features]

    # Guard the summary against an empty feature list, which previously
    # raised IndexError on ``mapping_features[0]``.
    first_shape = mapping_features[0].shape if len(mapping_features) > 0 else None
    print(
        "Mapping to the acoustic features {}, #sz = {}, feats[0] is {}".format(
            acoustic_features_name, len(mapping_features), first_shape
        )
    )
    return mapping_features


def extract_whisper_features_of_dataset(
    datasets,
    model,
    batch_size,
    out_dir,
):
    """Extract raw Whisper features for every utterance and save them as .npy.

    Utterances are encoded in batches with ``whisper_encoder_batch`` and each
    utterance's features are written to ``<out_dir>/<Uid>.npy``.

    Args:
        datasets: Sliceable sequence of utterance dicts with the keys
            ``"Path"`` (audio file) and ``"Uid"`` (output file stem).
        model: Whisper model passed on to ``whisper_encoder_batch``.
        batch_size: Maximum number of utterances encoded per forward pass.
        out_dir: Directory receiving the ``.npy`` files; created if missing.

    Raises:
        ValueError: If ``batch_size`` is not positive while there are
            utterances to process (this used to loop forever).
    """
    audio_paths = [utt["Path"] for utt in datasets]
    total = len(audio_paths)
    if total == 0:
        return

    if batch_size <= 0:
        raise ValueError(
            "batch_size must be positive, got {}".format(batch_size)
        )
    batch_size = min(batch_size, total)

    # An empty out_dir means "current directory", which needs no creation.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for start in range(0, total, batch_size):
        # Clamp so the last (possibly shorter) batch reports a correct count.
        end = min(start + batch_size, total)

        # Raw features: (batch_size, 1500, dim)
        tmp_raw_whisper_features = whisper_encoder_batch(
            model, audio_paths[start:end]
        )

        # One file per utterance, named after its Uid.
        for index, utt in enumerate(tqdm(datasets[start:end])):
            uid = utt["Uid"]
            raw_whisper_feature = tmp_raw_whisper_features[index]

            save_path = os.path.join(out_dir, uid + ".npy")
            np.save(save_path, raw_whisper_feature)

        print("{}/{} Done...".format(end, total))