from dataclasses import make_dataclass

import torch
import torchaudio
from torch import nn

from .usad_modules import ConformerEncoder

MAX_MEL_LENGTH = 3000  # 30 seconds


@torch.no_grad()
def wav_to_fbank(
    wavs: torch.Tensor,
    mel_dim: int = 128,
    norm_mean: float = -4.268,
    norm_std: float = 4.569,
) -> torch.Tensor:
    """Convert waveform to fbank features.

    Args:
        wavs (torch.Tensor): (B, T_wav) waveform tensor.
        mel_dim (int, optional): mel dimension. Defaults to 128.
        norm_mean (float, optional):
            mean for normalization. Defaults to -4.268.
        norm_std (float, optional):
            std for normalization. Defaults to 4.569.

    Returns:
        torch.Tensor: (B, T_mel, mel_dim) fbank features.
    """
    # ref: https://github.com/cwx-worst-one/EAT/tree/main/feature_extract
    dtype = wavs.dtype
    wavs = wavs.to(torch.float32)
    wavs = wavs - wavs.mean(dim=-1, keepdim=True)
    feats = [
        torchaudio.compliance.kaldi.fbank(
            wavs[i : i + 1],
            htk_compat=True,
            sample_frequency=16000,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=mel_dim,
            dither=0.0,
            frame_shift=10,
        ).to(dtype=dtype)
        for i in range(wavs.shape[0])
    ]
    mels = torch.stack(feats, dim=0)
    mels = (mels - norm_mean) / (norm_std * 2)
    return mels
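
# Shape sketch (illustrative, not part of the original module): with the Kaldi
# settings above (default 25 ms window, 10 ms shift, 16 kHz input),
#
#     wavs = torch.randn(2, 16000)   # (B, T_wav), 1 second per clip
#     mels = wav_to_fbank(wavs)      # -> (2, 98, 128): ~100 frames/s minus
#                                    #    window edge effects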


class UsadModel(nn.Module):
    def __init__(self, cfg) -> None:
        """Initialize the UsadModel.

        Args:
            cfg: Configuration object containing model parameters.
        """
        super().__init__()
        self.cfg = cfg
        self.encoder = ConformerEncoder(cfg)
        self.max_mel_length = MAX_MEL_LENGTH
        # NOTE: max_mel_length is set to 3000, which corresponds to
        # 30 seconds of audio at the 100 Hz mel frame rate.

    @property
    def sample_rate(self) -> int:
        return 16000  # Hz

    @property
    def encoder_frame_rate(self) -> int:
        return 50  # Hz

    @property
    def mel_dim(self) -> int:
        return self.cfg.input_dim

    @property
    def encoder_dim(self) -> int:
        return self.cfg.encoder_dim

    @property
    def num_layers(self) -> int:
        return self.cfg.num_layers

    @property
    def scene_embedding_size(self) -> int:
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def timestamp_embedding_size(self) -> int:
        return self.cfg.encoder_dim * self.cfg.num_layers

    @property
    def device(self) -> torch.device:
        """Get the device on which the model is located."""
        return next(self.parameters()).device
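
    # Note on the two embedding sizes above (interpretation, not stated in the
    # code): both equal encoder_dim * num_layers, which matches concatenating
    # one encoder_dim-sized vector per layer; e.g. with the assumed values
    # encoder_dim=768 and num_layers=12 the embedding would be 9216-dim.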

    def set_audio_chunk_size(self, seconds: float = 30.0) -> None:
        """Set the maximum chunk size for feature extraction.

        Args:
            seconds (float, optional): Chunk size in seconds. Defaults to 30.0.
        """
        assert (
            seconds >= 0.1
        ), f"Chunk size must be at least 0.1 seconds, got {seconds} seconds."
        self.max_mel_length = int(seconds * 100)  # 100 Hz mel frame rate
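
    # Example (illustrative): model.set_audio_chunk_size(10.0) limits each
    # encoder pass to int(10.0 * 100) = 1000 mel frames, i.e. 10 seconds;
    # longer inputs are then split into 1000-frame chunks in forward().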

    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load an audio file and return a waveform tensor.

        Args:
            audio_path (str): Path to the audio file.

        Returns:
            torch.Tensor: Waveform tensor of shape (wav_len,).
        """
        waveform, sr = torchaudio.load(audio_path)
        if sr != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, self.sample_rate)
        if waveform.shape[0] > 1:
            # If multi-channel, convert to mono by averaging channels
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform = waveform.squeeze(0)  # Drop the channel dimension
        return waveform.to(self.device)  # Move the waveform to the model's device
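
    # Example (illustrative; the path is a placeholder):
    #
    #     wav = model.load_audio("speech.wav")   # mono, 16 kHz, (wav_len,)
    #     wavs = wav.unsqueeze(0)                # add batch dim -> (1, wav_len)
    #     outputs = model(wavs)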

    def forward(
        self,
        wavs: torch.Tensor,
        norm_mean: float = -4.268,
        norm_std: float = 4.569,
    ) -> dict:
        """Forward pass for the model.

        Args:
            wavs (torch.Tensor):
                Input waveform tensor of shape (batch_size, wav_len).
            norm_mean (float, optional):
                Mean for normalization. Defaults to -4.268.
            norm_std (float, optional):
                Standard deviation for normalization. Defaults to 4.569.

        Returns:
            dict: A dictionary containing the model's outputs.
        """
        # wavs: (batch_size, wav_len)
        mel = wav_to_fbank(wavs, norm_mean=norm_mean, norm_std=norm_std)
        # Trim to an even number of mel frames so the 100 Hz mel sequence can
        # be downsampled to the encoder's 50 Hz output frame rate.
        mel = mel[:, : mel.shape[1] - mel.shape[1] % 2]

        # Inputs that fit within max_mel_length are encoded in a single pass.
        if mel.shape[1] <= self.max_mel_length:
            x, x_len, layer_results = self.encoder(mel, return_hidden=True)
            result = {
                "x": x,
                "mel": mel,
                "hidden_states": layer_results["hidden_states"],
                "ffn": layer_results["ffn_1"],
            }
            return result

        # Longer inputs are encoded chunk by chunk and the outputs are
        # concatenated along the time axis.
        result = {
            "x": [],
            "mel": mel,
            "hidden_states": [[] for _ in range(self.cfg.num_layers)],
            "ffn": [[] for _ in range(self.cfg.num_layers)],
        }
        for i in range(0, mel.shape[1], self.max_mel_length):
            # Skip a trailing chunk shorter than 10 frames (0.1 s).
            if mel.shape[1] - i < 10:
                break
            x, x_len, layer_results = self.encoder(
                mel[:, i : i + self.max_mel_length], return_hidden=True
            )
            result["x"].append(x)
            for j in range(self.cfg.num_layers):
                result["hidden_states"][j].append(layer_results["hidden_states"][j])
                result["ffn"][j].append(layer_results["ffn_1"][j])
        result["x"] = torch.cat(result["x"], dim=1)
        for j in range(self.cfg.num_layers):
            result["hidden_states"][j] = torch.cat(result["hidden_states"][j], dim=1)
            result["ffn"][j] = torch.cat(result["ffn"][j], dim=1)
        # result["x"]: model final output (batch_size, seq_len, encoder_dim)
        # result["mel"]: mel fbank (batch_size, seq_len * 2, mel_dim)
        # result["hidden_states"]: list of (batch_size, seq_len, encoder_dim)
        # result["ffn"]: list of (batch_size, seq_len, encoder_dim)
        return result
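
    # Worked example (illustrative; encoder_dim=768 and num_layers=12 are
    # assumed values, the real ones come from cfg): a 60 s, 16 kHz input of
    # shape (1, 960000) gives roughly 6000 mel frames, is encoded in two
    # chunks of at most max_mel_length (3000) frames, and returns
    # result["x"] and every tensor in result["hidden_states"] / result["ffn"]
    # with shape of about (1, 3000, 768), i.e. 50 feature frames per second.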

    @classmethod
    def load_from_fairseq_ckpt(cls, ckpt_path: str):
        """Build a UsadModel from a fairseq checkpoint, keeping only encoder weights."""
        checkpoint = torch.load(ckpt_path, weights_only=False)
        config = checkpoint["cfg"]["model"]
        config = make_dataclass("Config", config.keys())(**config)
        model = cls(config)
        state_dict = checkpoint["model"]
        # Drop everything that does not belong to the encoder before loading.
        for k in list(state_dict.keys()):
            if not k.startswith("encoder."):
                del state_dict[k]
        model.load_state_dict(state_dict, strict=True)
        return model
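

# Minimal usage sketch (illustrative; the checkpoint and audio paths are
# placeholders, not files provided by this module):
if __name__ == "__main__":
    model = UsadModel.load_from_fairseq_ckpt("/path/to/usad_checkpoint.pt")
    model.eval()

    wav = model.load_audio("/path/to/audio.wav")  # mono 16 kHz, (wav_len,)
    with torch.no_grad():
        outputs = model(wav.unsqueeze(0))  # batch of one waveform

    print(outputs["x"].shape)             # (1, T_feat, encoder_dim)
    print(len(outputs["hidden_states"]))  # one tensor per encoder layer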