# maskgct-audio-lab / processors/content_extractor.py
# Uploaded by Hecheng0625 ("Upload 409 files", commit c968fc3, verified)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import numpy as np
import yaml
import copy
from tqdm import tqdm
from torchaudio.compliance import kaldi
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from fairseq import checkpoint_utils
from transformers import AutoModel, Wav2Vec2FeatureExtractor
from utils.io_optim import (
TorchaudioDataset,
LibrosaDataset,
FFmpegDataset,
collate_batch,
)
import whisper
from modules.wenet_extractor.utils.init_model import init_model
from modules.wenet_extractor.utils.checkpoint import load_checkpoint
"""
Extractor for content features
1. whisper
2. contentvec
3. wenet
4. mert
Pipeline:
in preprocess.py:
call extract_utt_content_features() to extract content features for each utterance
extract_utt_content_features() envelopes the following steps:
1. load the model (whisper, contentvec, wenet)
2. extract the content features
3. save the content features into files
in svc_dataset.py:
call offline_align() to align the content features to the given target length
"""
"""
Extractor Usage:
1. initialize an instance of extractor
extractor = WhisperExtractor(cfg)
2. load the specified model
extractor.load_model()
3. extract the content features
extractor.extract_content(utt) for single utterance
extractor.extract_content_batch(utts) for batch utterances
4. save the content features
extractor.save_feature(utt, content_feature) for single utterance
"""
class AudioPretrainedModelFeaturesExtractor:
    """Base class for pretrained-model content feature extractors.

    Stores the config and the extractor type ("whisper", "contentvec", "wenet"
    or "mert"), and precomputes the hop ratio used to map the extractor's frame
    rate onto the acoustic frame rate given by ``cfg.preprocess.hop_size``.
    Subclasses provide ``load_model`` and ``extract_content_features``.
    """

    def __init__(self, cfg, extractor_type):
        self.cfg = cfg
        self.extractor_type = extractor_type
        self.model = None
        self.init_for_retrans()

    def init_for_retrans(self):
        """Compute ``self.source_hop`` / ``self.target_hop`` in lowest terms.

        The source hop is the extractor's frame shift multiplied by
        ``cfg.preprocess.sample_rate`` (frame shifts are assumed to be in
        seconds); the target hop is ``cfg.preprocess.hop_size``. Both are
        divided by their GCD so their ratio is the frame-rate ratio.

        Raises:
            NotImplementedError: if ``self.extractor_type`` is not supported.
        """
        preprocess = self.cfg.preprocess
        target_hop = preprocess.hop_size

        if self.extractor_type == "whisper":
            source_hop = (
                preprocess.whisper_frameshift
                * preprocess.whisper_downsample_rate
                * preprocess.sample_rate
            )
        elif self.extractor_type == "contentvec":
            source_hop = preprocess.contentvec_frameshift * preprocess.sample_rate
        elif self.extractor_type == "wenet":
            source_hop = (
                preprocess.wenet_frameshift
                * preprocess.wenet_downsample_rate
                * preprocess.sample_rate
            )
        elif self.extractor_type == "mert":
            # Same frame shift that get_valid_features() uses for MERT.
            source_hop = preprocess.mert_frameshift * preprocess.sample_rate
        else:
            raise NotImplementedError(
                "unsupported extractor type: {}".format(self.extractor_type)
            )

        source_hop = int(source_hop)
        factor = np.gcd(source_hop, target_hop)
        self.source_hop = int(source_hop // factor)
        self.target_hop = int(target_hop // factor)

    def offline_resolution_transformation(self, content, target_len):
        """Map one utterance's features onto the target frame rate.

        Every source frame is repeated ``source_hop`` times, the result is cut to
        a whole number of ``target_hop`` windows, and each window is averaged.
        If that yields fewer than ``target_len`` frames, the last frame is
        repeated to fill the gap.

        args:
            content: (source_len, dim)
            target_len: target length
        return:
            mapped_feature: (target_len, dim)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        _, width = content.shape
        # Only the first source_len frames can contribute to target_len outputs.
        source_len = min(target_len * target_hop // source_hop + 1, len(content))
        # const ~= target_len * target_hop: largest multiple of target_hop that
        # fits into the upsampled prefix.
        const = source_len * source_hop // target_hop * target_hop

        # (source_len * source_hop, dim); only the used prefix is upsampled
        # because everything past `const` would be discarded anyway.
        up_sampling_feats = np.repeat(content[:source_len], source_hop, axis=0)
        # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim)
        down_sampling_feats = np.average(
            up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1
        )

        err = abs(target_len - len(down_sampling_feats))
        if err > 8:
            self.log_for_ReTrans(err)

        if len(down_sampling_feats) < target_len:
            # (1, dim) -> (err, dim)
            end = down_sampling_feats[-1][None, :].repeat(err, axis=0)
            down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0)

        # (target_len, dim)
        return down_sampling_feats[:target_len]

    def log_for_ReTrans(self, err):
        """Keep the largest alignment error seen so far in a log file.

        The file is ``{cfg.preprocess.processed_dir}/align_max_err.log``. If it
        is missing or does not hold an integer, it is reset to ``"0"`` first;
        ``err`` is written only when it exceeds the stored value.
        """
        err_log_path = os.path.join(
            self.cfg.preprocess.processed_dir, "align_max_err.log"
        )
        try:
            with open(err_log_path, "r") as f:
                err_num = int(f.read())
        except (OSError, ValueError):
            # Missing/unreadable file or non-integer contents: start from 0.
            with open(err_log_path, "w") as f:
                f.write("0")
            err_num = 0

        if err > err_num:
            with open(err_log_path, "w") as f:
                f.write(str(err))

    def ReTrans(self, source_feats, padded_target_len):
        """
        Resolution Transformation for mismatched frames alignment.

        Batched tensor counterpart of offline_resolution_transformation().

        TODO: Merge the offline resolution_transformation into one

        args:
            source_feats: Tensor, (B, padded_source_len, D)
            padded_target_len: int, the maximum target length in a batch
        return:
            mapped_feature: Tensor, (B, padded_target_len, D)
        """
        source_hop = self.source_hop
        target_hop = self.target_hop

        # (B, padded_source_len, D)
        B, padded_source_len, D = source_feats.shape

        # select the valid content from padded feature
        source_len = min(
            padded_target_len * target_hop // source_hop + 1, padded_source_len
        )

        # const ~= padded_target_len * target_hop (padded wav's duration)
        const = source_len * source_hop // target_hop * target_hop

        # (B, source_len, D) -> (B, source_len * source_hop, D) -> (B, const, D)
        up_sampling_feats = torch.repeat_interleave(
            source_feats[:, :source_len], source_hop, dim=1
        )[:, :const]
        # (B, const, D) -> (B, const/target_hop, target_hop, D) -> (B, const/target_hop, D)
        down_sampling_feats = torch.mean(
            up_sampling_feats.reshape(B, -1, target_hop, D), dim=2
        )

        err = abs(padded_target_len - down_sampling_feats.shape[1])
        if err > 8:
            self.log_for_ReTrans(err)

        if down_sampling_feats.shape[1] < padded_target_len:
            # (B, 1, D) -> (B, err, D)
            end = down_sampling_feats[:, -1, :][:, None, :].repeat_interleave(
                err, dim=1
            )
            # -> (B, padded_target_len, D)
            down_sampling_feats = torch.cat([down_sampling_feats, end], dim=1)

        # (B, padded_target_len, D)
        return down_sampling_feats[:, :padded_target_len]

    def get_valid_features(self, utt, content_feature):
        """Trim a padded 2-D feature matrix to the frames covering ``utt["Duration"]``.

        The frame count is ``ceil((duration - frameshift) / frameshift) + 1``,
        where the frame shift depends on ``self.extractor_type``.

        Raises:
            NotImplementedError: for an unknown extractor type.
            AssertionError: if ``content_feature`` is not 2-D.
        """
        duration = utt["Duration"]
        if self.extractor_type == "whisper":
            frameshift = (
                self.cfg.preprocess.whisper_frameshift
                * self.cfg.preprocess.whisper_downsample_rate
            )  # 20ms
        elif self.extractor_type == "contentvec":
            frameshift = self.cfg.preprocess.contentvec_frameshift  # 20ms
        elif self.extractor_type == "wenet":
            frameshift = (
                self.cfg.preprocess.wenet_frameshift
                * self.cfg.preprocess.wenet_downsample_rate
            )  # 40ms
        elif self.extractor_type == "mert":
            frameshift = self.cfg.preprocess.mert_frameshift
        else:
            raise NotImplementedError

        # calculate the number of valid frames
        num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1
        assert (
            len(content_feature.shape) == 2
        ), "content feature shape error, it should be (num_frames, dim)"
        return content_feature[:num_frames, :]

    def save_feature(self, utt, content_feature):
        """Save a single utterance's trimmed features as a ``.npy`` file.

        The file goes to
        ``{cfg.preprocess.processed_dir}/{utt["Dataset"]}/{extractor_type}/{utt["Uid"]}.npy``.

        Args:
            utt (dict): one item in metadata, containing information for one utterance
            content_feature (tensor): content feature of one utterance
        """
        uid = utt["Uid"]
        assert self.extractor_type is not None
        out_dir = os.path.join(
            self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type
        )
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, uid + ".npy")

        content_feature = self.get_valid_features(utt, content_feature)
        np.save(save_path, content_feature.cpu().detach().numpy())
class WhisperExtractor(AudioPretrainedModelFeaturesExtractor):
    """Extract content features with the encoder of an OpenAI Whisper model."""

    def __init__(self, config):
        super(WhisperExtractor, self).__init__(config, extractor_type="whisper")
        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")

    def _whisper_download_root(self):
        """Return the ``download_root`` to pass to ``whisper.load_model``.

        ``None`` when ``whisper_model_path`` is not configured. An existing file,
        or a not-yet-existing path ending in ``.pt``, maps to its parent
        directory; any other path is used unchanged.
        """
        preprocess_cfg = self.cfg.preprocess
        if "whisper_model_path" not in preprocess_cfg:
            return None

        model_path = preprocess_cfg.whisper_model_path
        if os.path.isfile(model_path):
            # e.g. "pretrained/whisper/medium.pt"
            return os.path.dirname(model_path)
        if os.path.isdir(model_path):
            # e.g. "pretrained/whisper"
            return model_path
        # The path does not exist yet: the model gets downloaded there.
        return os.path.dirname(model_path) if model_path.endswith(".pt") else model_path

    def load_model(self):
        """Load ``cfg.preprocess.whisper_model`` and keep it in eval mode."""
        print("Loading Whisper Model...")
        download_root = self._whisper_download_root()
        model = whisper.load_model(
            self.cfg.preprocess.whisper_model, self.device, download_root
        )

        if not torch.cuda.is_available():
            print("Using CPU...\n")
        else:
            print("Using GPU...\n")
            model = model.cuda()

        self.model = model.eval()

    def extract_content_features(self, wavs):
        """Run a batch of waveforms through the Whisper audio encoder.

        Args:
            wavs: tensor (batch_size, T)
        Returns:
            Encoder outputs, (batch, 1500, 1024) per the original shape notes.
        """
        # Pad or trim every waveform to Whisper's fixed input length.
        fixed_len_wavs = whisper.pad_or_trim(wavs)
        # (batch, 80, 3000)
        mel = whisper.log_mel_spectrogram(fixed_len_wavs, device=self.model.device)
        with torch.no_grad():
            return self.model.embed_audio(mel)
class ContentvecExtractor(AudioPretrainedModelFeaturesExtractor):
    """Extract content features with a fairseq ContentVec checkpoint."""

    def __init__(self, cfg):
        super(ContentvecExtractor, self).__init__(cfg, extractor_type="contentvec")

    def load_model(self):
        """Load ``cfg.preprocess.contentvec_file`` in eval mode (on GPU if available)."""
        assert self.model is None
        # Load model
        ckpt_path = self.cfg.preprocess.contentvec_file
        print("Load Contentvec Model...")

        models, _saved_cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
            [ckpt_path],
            suffix="",
        )
        model = models[0]
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()

        self.model = model

    def extract_content_features(self, wavs, lens=None):
        """Extract content features from a batch of dataloader.

        Args:
            wavs: tensor (batch, T), zero-padded waveforms.
            lens: optional per-row valid sample counts. When given, the padding
                mask marks exactly the positions past each length. When omitted
                (legacy behaviour), every sample equal to 0 is treated as
                padding, which also masks exact-zero silence inside an utterance.
        Returns:
            Features after ``final_proj``, (batch, T, 256) per the original notes.
        """
        device = next(self.model.parameters()).device
        wavs = wavs.to(device)  # (batch, max_len)

        if lens is None:
            padding_mask = torch.eq(wavs, torch.zeros_like(wavs))
        else:
            lens = torch.as_tensor(lens, device=device)
            positions = torch.arange(wavs.shape[-1], device=device)
            # True where the position lies beyond the utterance's valid length.
            padding_mask = positions[None, :] >= lens[:, None]

        with torch.no_grad():
            logits = self.model.extract_features(
                source=wavs, padding_mask=padding_mask, output_layer=12
            )
            # feats: (batch, T, 256)
            feats = self.model.final_proj(logits[0])
        return feats
class WenetExtractor(AudioPretrainedModelFeaturesExtractor):
    """Extract content features with a WeNet ASR encoder."""

    def __init__(self, config):
        super(WenetExtractor, self).__init__(config, extractor_type="wenet")
        # The config's "dataset_conf" section; populated by load_model().
        self.extract_conf = None

    def load_model(self):
        """Load the WeNet YAML config and checkpoint and put the model in eval mode.

        Reads ``cfg.preprocess.wenet_config`` and ``cfg.preprocess.wenet_model_path``
        and keeps a deep copy of the config's ``dataset_conf`` section for
        feature extraction.
        """
        wenet_cfg = self.cfg.preprocess.wenet_config
        wenet_model_path = self.cfg.preprocess.wenet_model_path
        # load Wenet config
        # NOTE(review): FullLoader is acceptable for local, trusted config files;
        # switch to yaml.safe_load if this path can come from untrusted input.
        with open(wenet_cfg, "r") as w:
            wenet_configs = yaml.load(w, Loader=yaml.FullLoader)
        self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"])

        print("Loading Wenet Model...")
        self.model = init_model(wenet_configs)
        load_checkpoint(self.model, wenet_model_path)

        if torch.cuda.is_available():
            print("Using GPU...\n")
            self.model = self.model.cuda()
        else:
            print("Using CPU...\n")

        self.model = self.model.eval()

    def extract_content_features(self, wavs, lens):
        """Extract content features from a batch of dataloader.

        Each waveform is cut to its valid length and padded with 160 zero
        samples. It is then scaled by 2**15 and turned into Kaldi fbank or MFCC
        features (per ``dataset_conf["feats_type"]``). The padded batch of
        features goes through ``self.model.encoder_extractor``.

        Args:
            wavs: tensor, whose shape is (B, T)
            lens: list of valid sample counts, one per row of ``wavs``
        Returns:
            Whatever ``encoder_extractor`` returns for the batch.
        """
        # Check first so a missing load_model() gives this message rather than
        # an AttributeError from self.model.
        assert self.extract_conf is not None, "load model first!"
        device = next(self.model.parameters()).device

        # Extract fbank/mfcc features by kaldi
        feats_type = self.extract_conf.get("feats_type", "fbank")
        assert feats_type in ["fbank", "mfcc"]
        if feats_type == "fbank":
            fbank_conf = self.extract_conf.get("fbank_conf", {})
        else:
            # WeNet's dataset_conf stores MFCC options under "mfcc_conf"; keep the
            # legacy "mfcc" key as a fallback for configs written for old code.
            mfcc_conf = self.extract_conf.get(
                "mfcc_conf", self.extract_conf.get("mfcc", {})
            )

        feats_list = []
        lengths_list = []
        for idx, wav in enumerate(wavs):
            # wav: (T)
            wav = wav[: lens[idx]].to(device)

            # pad one frame to compensate for the frame cut off after feature extraction
            pad_tensor = torch.zeros(160, device=wav.device)
            wav = torch.cat((wav, pad_tensor), dim=-1)
            # Scale by 2**15; `wav` is a fresh tensor from torch.cat, so the
            # caller's batch is not modified.
            wav *= 1 << 15

            wav = wav.unsqueeze(0)  # (T) -> (1, T)
            if feats_type == "fbank":
                feat = kaldi.fbank(
                    wav,
                    sample_frequency=16000,
                    num_mel_bins=fbank_conf["num_mel_bins"],
                    frame_length=fbank_conf["frame_length"],
                    frame_shift=fbank_conf["frame_shift"],
                    dither=fbank_conf["dither"],
                )
            else:
                feat = kaldi.mfcc(
                    wav,
                    sample_frequency=16000,
                    num_mel_bins=mfcc_conf["num_mel_bins"],
                    frame_length=mfcc_conf["frame_length"],
                    frame_shift=mfcc_conf["frame_shift"],
                    dither=mfcc_conf["dither"],
                    num_ceps=mfcc_conf.get("num_ceps", 40),
                    high_freq=mfcc_conf.get("high_freq", 0.0),
                    low_freq=mfcc_conf.get("low_freq", 20.0),
                )
            feats_list.append(feat)
            lengths_list.append(feat.shape[0])

        feats_lengths = torch.tensor(lengths_list, dtype=torch.int32).to(device)
        # (batch, len, num_mel_bins)
        feats_tensor = pad_sequence(feats_list, batch_first=True).to(device)

        with torch.no_grad():
            features = self.model.encoder_extractor(
                feats_tensor,
                feats_lengths,
                decoding_chunk_size=-1,
                num_decoding_left_chunks=-1,
                simulate_streaming=False,
            )
        return features
class MertExtractor(AudioPretrainedModelFeaturesExtractor):
    """Extract content features from one hidden layer of a MERT model."""

    def __init__(self, cfg):
        super(MertExtractor, self).__init__(cfg, extractor_type="mert")
        self.preprocessor = None

    def load_model(self):
        """Load the MERT model and its feature extractor from ``cfg.preprocess.mert_model``."""
        assert self.model is None
        assert self.preprocessor is None

        print("Loading MERT Model: ...", self.cfg.preprocess.mert_model)

        model_name = self.cfg.preprocess.mert_model
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository; only point mert_model at trusted sources.
        model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        # Explicit inference mode so dropout never affects extracted features.
        model.eval()

        if torch.cuda.is_available():
            model = model.cuda()
        preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(
            model_name, trust_remote_code=True
        )

        self.model = model
        self.preprocessor = preprocessor

    def extract_content_features(self, wavs):
        """Extract content features from a batch of dataloader.

        Each waveform in the batch is preprocessed and encoded separately. The
        hidden state at ``cfg.preprocess.mert_feature_layer`` is kept.

        Args:
            wavs: tensor (batch, T)
        Returns:
            list with one (frame_len, hidden_dim) tensor per waveform.
        Raises:
            AssertionError: if the preprocessor's sampling rate differs from
                ``cfg.preprocess.mert_sample_rate``.
        """
        with torch.no_grad():
            sample_rate = self.preprocessor.sampling_rate
            device = next(self.model.parameters()).device
            assert (
                sample_rate == self.cfg.preprocess.mert_sample_rate
            ), "mert sample rate mismatch, expected {}, got {}".format(
                self.cfg.preprocess.mert_sample_rate, sample_rate
            )
            feature_layer = self.cfg.preprocess.mert_feature_layer

            mert_features = []
            # wav: (len)
            for wav in wavs:
                # Preprocess this utterance only (previously the whole batch was
                # passed here on every iteration).
                # {input_values: tensor, attention_mask: tensor}
                inputs = self.preprocessor(
                    wav, sampling_rate=sample_rate, return_tensors="pt"
                ).to(device)

                outputs = self.model(**inputs, output_hidden_states=True)

                # (1, frame_len, hidden_dim) -> (frame_len, hidden_dim)
                feature = outputs.hidden_states[feature_layer].squeeze(0)
                mert_features.append(feature)

        return mert_features
def _extract_and_save_features(
    cfg,
    metadata,
    dataset_name,
    num_workers,
    feature_type,
    dataset_cls,
    sample_rate_key,
    extractor_cls,
    pass_lens=False,
):
    """Extract and save one feature type for a dataset, unless already complete.

    Args:
        cfg: global config; reads ``cfg.preprocess.*``.
        metadata: list of utterance dicts for ``dataset_name``.
        dataset_name: dataset sub-directory under ``cfg.preprocess.processed_dir``.
        num_workers: DataLoader worker count.
        feature_type: output sub-directory name ("whisper", "contentvec", ...).
        dataset_cls: waveform dataset class used to load audio.
        sample_rate_key: name of the ``cfg.preprocess`` attribute holding the
            sample rate to load audio at. It is read only if extraction runs.
        extractor_cls: extractor class to instantiate.
        pass_lens: also pass the batch lengths to ``extract_content_features``.
    """
    feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, feature_type)
    os.makedirs(feat_dir, exist_ok=True)
    # Skip when the directory already holds one file per utterance.
    if len(os.listdir(feat_dir)) == len(metadata):
        return

    waveforms = dataset_cls(
        cfg,
        dataset_name,
        getattr(cfg.preprocess, sample_rate_key),
        metadata=metadata,
    )
    data_loader = DataLoader(
        waveforms,
        num_workers=num_workers,
        shuffle=False,
        pin_memory=cfg.preprocess.pin_memory,
        batch_size=cfg.preprocess.content_feature_batch_size,
        collate_fn=collate_batch,
        drop_last=False,
    )
    extractor = extractor_cls(cfg)
    extractor.load_model()
    for _metadata, wavs, lens in tqdm(data_loader):
        if pass_lens:
            batch_content_features = extractor.extract_content_features(wavs, lens)
        else:
            batch_content_features = extractor.extract_content_features(wavs)
        for index, utt in enumerate(_metadata):
            extractor.save_feature(utt, batch_content_features[index])


def extract_utt_content_features_dataloader(cfg, metadata, num_workers):
    """Extract and save every content feature type enabled in ``cfg.preprocess``.

    Feature types are processed in the order whisper, contentvec, wenet, mert,
    each gated by its ``cfg.preprocess.extract_<type>_feature`` flag. A type
    whose output directory already holds ``len(metadata)`` files is skipped.
    All utterances are assumed to belong to ``metadata[0]["Dataset"]``.
    An empty ``metadata`` is a no-op.
    """
    if not metadata:
        return
    dataset_name = metadata[0]["Dataset"]

    with torch.no_grad():
        if cfg.preprocess.extract_whisper_feature:
            _extract_and_save_features(
                cfg, metadata, dataset_name, num_workers,
                "whisper", FFmpegDataset, "whisper_sample_rate", WhisperExtractor,
            )

        if cfg.preprocess.extract_contentvec_feature:
            _extract_and_save_features(
                cfg, metadata, dataset_name, num_workers,
                "contentvec", LibrosaDataset, "contentvec_sample_rate",
                ContentvecExtractor,
            )

        if cfg.preprocess.extract_wenet_feature:
            # WeNet needs the valid lengths to trim the zero padding.
            _extract_and_save_features(
                cfg, metadata, dataset_name, num_workers,
                "wenet", TorchaudioDataset, "wenet_sample_rate", WenetExtractor,
                pass_lens=True,
            )

        if cfg.preprocess.extract_mert_feature:
            _extract_and_save_features(
                cfg, metadata, dataset_name, num_workers,
                "mert", TorchaudioDataset, "mert_sample_rate", MertExtractor,
            )