File size: 11,339 Bytes
b386992 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 |
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is designed to extract features from different layers of a pretrained SSL model.
The extracted features will be in *.npy format, and in the shape of [L, D, T], where L is the
number of layers, D is the feature dimension, and T is the time dimension.
Example usage:
python extract_features.py \
--model_path="nvidia/ssl_en_nest_large_v1.0" \
--input=<path to input manifest, or a dir containing audios, or path to audio> \
--output=<output directory to store features and manifest> \
--layers="all" \
--batch_size=8 \
--workers=8 \
--max_cache=1000 # save features every 1000 samples to avoid OOM in system memory
"""
import argparse
import os
import tempfile
from pathlib import Path
from typing import List
import lightning.pytorch as pl
import numpy as np
import torch
from tqdm import tqdm
from nemo.collections.asr.data.audio_to_text_dataset import get_char_dataset
from nemo.collections.asr.models import EncDecDenoiseMaskedTokenPredModel
from nemo.collections.asr.modules import ConformerMultiLayerFeatureExtractor
from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest
from nemo.collections.common.data.utils import move_data_to_device
from nemo.collections.common.parts.preprocessing.manifest import get_full_path
from nemo.core.classes.common import typecheck
from nemo.utils import logging
# Turn off `typecheck` for the whole script run.
typecheck.set_typecheck_enabled(enabled=False)

# Command-line interface. NOTE: arguments are parsed at import time, and the resulting
# module-level `args` is read by `get_input_manifest` (for `args.type`) as well as by the
# `__main__` entry point.
parser = argparse.ArgumentParser(description="Extract audio features using an SSL model")

# Model and input/output locations
parser.add_argument(
    "--model_path",
    required=True,
    type=str,
    help="Path to the .nemo model file or a pretrained model name from the NGC/HF model hub",
)
parser.add_argument(
    "-i",
    "--input",
    required=True,
    type=str,
    help="Path to the input audio file, or list of files, directory or jsonl manifest",
)
parser.add_argument(
    "-o",
    "--output",
    required=True,
    type=str,
    help="Path to the output directory that contains .npy file",
)

# Which encoder layers to export
parser.add_argument(
    "-l",
    "--layers",
    default="all",
    type=str,
    help="Layers to extract features from, use 'all' to extract from all layer, 'last' for last layer, "
    "or comma-separated indices of the target layers (e.g. '0,1,2')",
)

# Data loading and device options
parser.add_argument(
    "-b",
    "--batch_size",
    default=8,
    type=int,
    help="Batch size for feature extraction",
)
parser.add_argument(
    "-w",
    "--workers",
    default=8,
    type=int,
    help="Number of workers for feature extraction",
)
parser.add_argument(
    "-d",
    "--device",
    default="cuda",
    type=str,
    help="Device to use for feature extraction",
)
parser.add_argument(
    "-t",
    "--type",
    default="wav",
    type=str,
    help="audio file type, only needed for directory input",
)

# Mixed-precision options
parser.add_argument(
    "--use_amp",
    action="store_true",
    help="Use automatic mixed precision",
)
parser.add_argument(
    "--amp_dtype",
    choices=["float16", "bfloat16"],
    default="float16",
    type=str,
    help="Data type for automatic mixed precision",
)

# Number of cached samples that triggers a flush to disk; the default of -1 keeps
# everything in memory until extraction has finished.
parser.add_argument(
    "-mc",
    "--max_cache",
    default=-1,
    type=int,
    help="Max cache size before saving features",
)

args = parser.parse_args()
def get_input_manifest(input: str) -> List[dict]:
    """
    Build an in-memory manifest from a manifest file, a directory of audio files, or a single audio file.

    Args:
        input: path to an existing ``.json``/``.jsonl`` manifest file, a directory that is searched
            recursively for ``*.<args.type>`` files, or a single audio file.

    Returns:
        A list of manifest entries. Every entry has the keys ``audio_filepath``, ``duration``
        (always ``None``) and ``text`` (always ``"-"``).

    Raises:
        ValueError: if ``input`` is neither an existing file nor an existing directory.
    """
    is_file = os.path.isfile(input)
    # The file check must apply to both manifest extensions. Without it, a missing "*.json" path
    # (or a directory whose name ends in ".json") would be treated as a manifest.
    if is_file and input.endswith((".json", ".jsonl")):
        logging.info(f"Reading manifest from: {input}")
        manifest = [
            {"audio_filepath": str(get_full_path(item["audio_filepath"], input)), "duration": None, "text": "-"}
            for item in read_manifest(input)
        ]
    elif os.path.isdir(input):
        logging.info(f"Creating manifest from directory: {input}")
        # The file extension comes from the module-level command-line arguments.
        manifest = [
            {"audio_filepath": str(p), "duration": None, "text": "-"} for p in Path(input).rglob(f"*.{args.type}")
        ]
        logging.info(f"Found {len(manifest)} items of {args.type} files")
    elif is_file:
        logging.info(f"Reading single file: {input}")
        # `Path.absolute` is a method and has to be called before `as_posix`.
        manifest = [{"audio_filepath": Path(input).absolute().as_posix(), "duration": None, "text": "-"}]
    else:
        raise ValueError(f"Invalid input: {input}")
    return manifest
def load_model(model_path):
    """
    Return an ``EncDecDenoiseMaskedTokenPredModel`` for ``model_path``.

    A path that ends with ``.nemo`` and points to an existing file is restored from disk;
    anything else is handed to ``from_pretrained`` as a model name.
    """
    is_local_checkpoint = model_path.endswith(".nemo") and os.path.isfile(model_path)
    if not is_local_checkpoint:
        logging.info(f"Loading model from pretrained: {model_path}")
        return EncDecDenoiseMaskedTokenPredModel.from_pretrained(model_name=model_path)

    logging.info(f"Loading model from local: {model_path}")
    return EncDecDenoiseMaskedTokenPredModel.restore_from(model_path)
class FeatureExtractor(pl.LightningModule):
    """
    Lightning wrapper that runs the preprocessor and encoder of an SSL model and returns the
    outputs of the selected encoder layers.
    """

    def __init__(self, ssl_model: EncDecDenoiseMaskedTokenPredModel, layer: str = "all"):
        """
        Args:
            ssl_model: model whose ``preprocessor``, ``encoder`` and ``cfg.sample_rate`` are reused.
            layer: ``"all"``, ``"last"``, or comma-separated layer indices such as ``"0,1,2"``.

        Raises:
            ValueError: if ``layer`` is none of the accepted forms.
        """
        super().__init__()
        self.preprocessor = ssl_model.preprocessor
        self.encoder = ssl_model.encoder
        # Stays None for "all"; the value is handed to the multi-layer extractor as-is.
        self.layer_idx_list = None
        self.sample_rate = ssl_model.cfg.sample_rate

        if layer == "last":
            last_index = len(self.encoder.layers) - 1
            self.layer_idx_list = [last_index]
        elif layer != "all":
            try:
                self.layer_idx_list = [int(index) for index in layer.split(",")]
            except Exception as e:
                raise ValueError(f"Invalid layer argument: {layer}. Error: {e}")

        self.feature_extractor = ConformerMultiLayerFeatureExtractor(
            self.encoder,
            aggregator=None,
            layer_idx_list=self.layer_idx_list,
        )

    def forward(
        self,
        input_signal=None,
        input_signal_length=None,
        processed_signal=None,
        processed_signal_length=None,
    ):
        """
        Extract features from either raw audio or already preprocessed audio.

        Exactly one of the pairs (``input_signal``, ``input_signal_length``) and
        (``processed_signal``, ``processed_signal_length``) must be given; raw audio is passed
        through ``self.preprocessor`` first.

        Returns:
            The ``(encoded, encoded_len)`` pair produced by the multi-layer feature extractor.

        Raises:
            ValueError: if both pairs or neither pair is provided.
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None

        # Both flags are plain booleans, so equality means "both given" or "neither given".
        if has_input_signal == has_processed_signal:
            raise ValueError(
                f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive "
                " with ``processed_signal`` and ``processed_signal_len`` arguments."
            )

        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal,
                length=input_signal_length,
            )

        layer_outputs, layer_lengths = self.feature_extractor(
            audio_signal=processed_signal,
            length=processed_signal_length,
        )
        return layer_outputs, layer_lengths
def maybe_save_features(output_dir, results, max_cache, manifest):
    """
    Flush cached features to ``output_dir`` once the cache holds at least ``max_cache`` entries.

    Args:
        output_dir: directory that receives one ``.npy`` file per sample; created if missing.
        results: list of ``(sample_id, audio_file, features_np)`` tuples. It is emptied in place
            after the features have been written.
        max_cache: number of cached entries that triggers a save. A negative value never saves,
            ``0`` saves whatever is currently cached.
        manifest: manifest entries indexed by ``sample_id``; the path of each saved file is stored
            under the ``feature_path`` key of the matching entry.
    """
    if not results or max_cache < 0 or len(results) < max_cache:
        return

    extension = ".npy"
    # Common filesystems cap a single file name at 255, and the limit applies to the name
    # including its extension, so the extension has to be budgeted for when truncating.
    # NOTE(review): the limit is counted in bytes on most Linux filesystems; names with
    # multi-byte characters may need a tighter bound.
    max_stem_len = 255 - len(extension)

    os.makedirs(output_dir, exist_ok=True)
    logging.info(f"Saving {len(results)} features to {output_dir}")
    for sample_id, audio_file, features_np in tqdm(results, desc="Saving features", total=len(results)):
        # Flatten the audio path into a single file name.
        filename = str(audio_file).replace("/", "_").replace(".", "_")
        if len(filename) > max_stem_len:
            # Keep the tail of the path, which carries the most specific part of the name.
            filename = filename[-max_stem_len:]
        output_path = os.path.join(output_dir, f"{filename}{extension}")
        np.save(output_path, features_np)
        manifest[sample_id]["feature_path"] = output_path
    logging.info(f"Saved {len(results)} features to {output_dir}")
    results.clear()
def extract_features(args):
    """
    Main function to extract and save features from SSL model.

    Loads the model, builds a dataloader over the input manifest, runs the feature extractor batch
    by batch, and writes one ``.npy`` file per sample (layers stacked along the first axis) plus a
    ``features.json`` manifest into ``args.output``.

    Args:
        args: parsed command-line namespace; reads ``model_path``, ``layers``, ``device``, ``input``,
            ``output``, ``batch_size``, ``workers``, ``use_amp``, ``amp_dtype`` and ``max_cache``.
    """
    logging.info(f"Extracting features using params: {vars(args)}")

    # Load model
    model = load_model(args.model_path)
    feature_extractor = FeatureExtractor(model, args.layers)
    device = torch.device(args.device)
    feature_extractor.to(device)
    # `torch.inference_mode` only disables autograd; eval mode is what switches off
    # training-only behaviour such as dropout, so features are reproducible.
    feature_extractor.eval()

    # Load data
    logging.info(f"Building dataset from input: {args.input}")
    manifest = get_input_manifest(args.input)
    total_num_samples = len(manifest)

    # Only the path of the temporary manifest is needed, so release the handle immediately.
    tmp_manifest = tempfile.NamedTemporaryFile(mode="w", delete=False)
    tmp_manifest.close()
    try:
        write_manifest(tmp_manifest.name, manifest)

        # Build dataloader
        config = {
            "manifest_filepath": tmp_manifest.name,
            "sample_rate": feature_extractor.sample_rate,
            "return_sample_id": True,
        }
        dataset = get_char_dataset(config)
        logging.info(f"Built dataset with {len(dataset)} samples")
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=False,
        )

        # Extract features
        indices = set()
        results = []
        amp_dtype = torch.float16 if args.amp_dtype == "float16" else torch.bfloat16
        logging.info(f"Extracting features using AMP: {args.use_amp}, dtype: {amp_dtype}")
        # Autocast must target the device the model actually runs on, not whichever device
        # happens to be available on the machine.
        with torch.amp.autocast(device.type, dtype=amp_dtype, enabled=args.use_amp), torch.inference_mode():
            for batch in tqdm(dataloader, desc="Extracting features"):
                batch = move_data_to_device(batch, device)
                audio_signal, audio_signal_len, _, _, sample_id = batch
                features, features_len = feature_extractor(
                    input_signal=audio_signal, input_signal_length=audio_signal_len
                )
                batch_size = features[0].size(0)
                # Every layer is trimmed with the lengths reported for the first returned layer.
                first_layer_len = features_len[0]
                for i in range(batch_size):
                    # Convert to a plain int: tensors hash by identity, so keeping tensor ids in
                    # the set would never detect a duplicate.
                    sid_i = int(sample_id[i])
                    if sid_i in indices:
                        logging.warning(f"Skipping duplicated sample_id: {sid_i}")
                        continue
                    feat_i_len = int(first_layer_len[i])
                    # Drop the padded frames of each layer, then stack the layers along dim 0.
                    feat_i = torch.stack([layer_feat[i][:, :feat_i_len] for layer_feat in features], dim=0)
                    if feat_i.dtype == torch.bfloat16:
                        # NumPy has no bfloat16 dtype, so `.numpy()` would raise on such tensors.
                        feat_i = feat_i.float()
                    indices.add(sid_i)
                    results.append((sid_i, manifest[sid_i]["audio_filepath"], feat_i.cpu().numpy()))
                maybe_save_features(args.output, results, args.max_cache, manifest)

        # Flush whatever is still cached (a threshold of 0 saves any non-empty cache).
        maybe_save_features(args.output, results, 0, manifest)

        # The directory may not exist yet when no features were written at all.
        os.makedirs(args.output, exist_ok=True)
        output_manifest = Path(args.output) / "features.json"
        write_manifest(output_manifest, manifest)
    finally:
        # Remove the temporary manifest even when extraction fails part-way.
        os.remove(tmp_manifest.name)

    logging.info(f"Extracted features from {total_num_samples} samples to {args.output}")
    logging.info(f"Manifest saved to: {output_manifest}")
# Script entry point: run the extraction with the module-level `args` parsed above.
if __name__ == "__main__":
    extract_features(args)
|