import torch
import torchaudio
import torch.nn as nn

import modules.p1cupe.model_utils as model_utils


class FinetuneXLSR(nn.Module):
    def __init__(self, hp, input_wav_length, freeze_feature_encoder=False):
        super().__init__()

        self.hp = hp
        self.noise_level = hp.noise_level
        self.xls_dim = 1024  # hidden size of the XLS-R 300M encoder
        self.output_dim = hp.phoneme_classes + 1  # +1, presumably for a CTC blank class

        # Pretrained wav2vec 2.0 XLS-R (300M) bundle; get_model() downloads
        # and caches the checkpoint on first use.
        bundle = torchaudio.pipelines.WAV2VEC2_XLSR_300M
        self.XLSR = bundle.get_model()
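        # The returned Wav2Vec2Model is a convolutional feature_extractor
        # followed by a transformer encoder; extract_features() exposes the
        # per-layer encoder outputs used in forward() below.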

        def reset_parameters(module):
            if hasattr(module, 'reset_parameters'):
                module.reset_parameters()

        # Optionally re-initialize all XLSR weights, i.e. train the
        # architecture from scratch instead of fine-tuning.
        if getattr(self.hp, 'xlrs_reset', False):
            print("reset_parameters for XLSR")
            self.XLSR.apply(reset_parameters)

        self.freeze_feature_encoder = freeze_feature_encoder
        if self.freeze_feature_encoder:
            # torchaudio's Wav2Vec2Model exposes the convolutional front end
            # directly as `feature_extractor`.
            for param in self.XLSR.feature_extractor.parameters():
                param.requires_grad = False

        # Frame-wise classification head on top of the encoder output.
        self.classifier = nn.Sequential(
            nn.Linear(self.xls_dim, self.xls_dim),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(self.xls_dim, self.output_dim),
        )

        self.layer_dims = model_utils.ModelUtils.extract_layer_dims(self)
        self.frames_per_window = model_utils.ModelUtils.calculate_layer_sizes(
            self.layer_dims, torch.tensor([input_wav_length]), -1)[0].int()
        # The tensor is already integer-valued at this point, so ceil-ing it
        # adds nothing; the net effect is dropping one frame from the count.
        self.frames_per_window = self.frames_per_window - 1
        self.model_utils = model_utils.ModelUtils(self.layer_dims, input_wav_length, self.frames_per_window)
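
    # Frame-rate note (an assumption about the standard wav2vec 2.0 conv
    # stack, not something this file asserts): the feature extractor emits
    # roughly one frame per 320 samples (~20 ms at 16 kHz), so a 1 s clip
    # yields on the order of 50 frames; the exact count comes from
    # ModelUtils.calculate_layer_sizes above.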
    def update_frames_per_window(self, input_wav_length):
        self.frames_per_window = self.model_utils.calculate_layer_sizes(
            self.layer_dims, torch.tensor([input_wav_length]), -1)[0].int()
        self.frames_per_window = self.frames_per_window - 1
        print("frames_per_window (frames per clip if disable_windowing):", self.frames_per_window.item())
        return self.frames_per_window

    def forward(self, x):
        # Additive Gaussian noise augmentation, applied only during training.
        if self.training and self.noise_level > 0:
            x = x + torch.randn_like(x) * self.noise_level

        # extract_features returns (per-layer feature list, lengths); keep
        # only the last transformer layer's output.
        features, _ = self.XLSR.extract_features(x)
        features = features[-1]

        logits = self.classifier(features)
        return logits
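
    # Shape sketch (assuming mono 16 kHz input; the frame rate is the usual
    # wav2vec 2.0 convention, not something this file asserts):
    #   x        : (batch, num_samples)
    #   features : (batch, num_frames, 1024), num_frames ~ num_samples / 320
    #   logits   : (batch, num_frames, phoneme_classes + 1)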


def test_model():
    # Fall back to CPU so the smoke test also runs without a GPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(torch.__version__)
    print(torchaudio.__version__)
    torch.random.manual_seed(0)
    print(device)

    # FinetuneXLSR expects an hp config object and an input length in
    # samples; this stub uses placeholder values (phoneme_classes=40 and the
    # 1 s clip length are hypothetical, not taken from the training config).
    from types import SimpleNamespace
    hp = SimpleNamespace(noise_level=0.01, phoneme_classes=40)
    input_wav_length = 16000

    model = FinetuneXLSR(hp, input_wav_length, freeze_feature_encoder=False).to(device)

    print(model.__class__)

    sample_path = "tmp/data/audio_samples/9860_8338_000010.flac.wav"
    sample_waveform, sample_samplerate = torchaudio.load(sample_path)
    sample_waveform = sample_waveform.to(device)
    # The XLS-R models expect 16 kHz input, so resample before inference.
    batch = torchaudio.functional.resample(sample_waveform, sample_samplerate, 16000)
    print(batch.shape)

    with torch.inference_mode():
        logits = model(batch)

    # logits has shape (batch, frames, classes); print per-item shapes
    # rather than dumping the full tensors.
    print(len(logits), logits[0].shape)
    for i, element in enumerate(logits):
        print(i, element.shape)


def main():
    test_model()


if __name__ == "__main__":
    main()