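# Compares a converted Hugging Face Transformers Wav2Vec2 checkpoint against the original
# fairseq checkpoint: the convolutional feature extractor, the full encoder and (for
# fine-tuned models) the CTC model outputs are checked for equivalence on a few LibriSpeech samples.
#
# Usage: python <this_script>.py <hf_checkpoint_path> <fairseq_checkpoint_path> <finetuned: 0 or 1>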
import sys

import datasets
import fairseq
import soundfile as sf
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2Processor


hf_path = str(sys.argv[1])
fairseq_wav2vec2_path = str(sys.argv[2])
finetuned = bool(int(sys.argv[3]))
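
# A fine-tuned checkpoint is compared as Wav2Vec2ForCTC together with a Wav2Vec2Processor
# (feature extractor + tokenizer); a pretrained-only checkpoint as Wav2Vec2Model with a bare
# Wav2Vec2FeatureExtractor.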
if finetuned:
    processor = Wav2Vec2Processor.from_pretrained(hf_path)
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [fairseq_wav2vec2_path], arg_overrides={"data": "../add_wav2vec/data/temp"}
    )
    hf_model = Wav2Vec2ForCTC.from_pretrained(hf_path)
else:
    processor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path)
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
    hf_model = Wav2Vec2Model.from_pretrained(hf_path)

# fairseq returns a list of models (an ensemble); take the single model and put it in eval mode
model = model[0]
model.eval()
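

# The tests below compare the two implementations with torch.allclose: atol=1e-3 for the
# convolutional feature extractor, atol=1e-2 for the encoder and full-model outputs.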
def test_feature_extractor(hf_feat_extractor, fsq_feat_extract, example_wav):
    fsq_output = fsq_feat_extract(example_wav)
    hf_output = hf_feat_extractor(example_wav)

    assert (
        hf_output.shape == fsq_output.shape
    ), f"Shapes don't match. Got {hf_output.shape} for HF and {fsq_output.shape} for fsq"
    assert torch.allclose(hf_output, fsq_output, atol=1e-3)


def test_full_encoder(hf_model, fsq_model, example_wav, attention_mask):
    fsq_output = fsq_model(example_wav, padding_mask=attention_mask.ne(1), mask=False, features_only=True)["x"]
    hf_output = hf_model(example_wav, attention_mask=attention_mask)[0]

    assert (
        hf_output.shape == fsq_output.shape
    ), f"Shapes don't match. Got {hf_output.shape} for HF and {fsq_output.shape} for fsq"
    assert torch.allclose(hf_output, fsq_output, atol=1e-2)


def test_full_model(hf_model, fsq_model, example_wav, attention_mask):
    fsq_output = fsq_model(source=example_wav, padding_mask=attention_mask.ne(1))["encoder_out"]
    # fairseq's encoder output is time-first (time, batch, vocab); transpose the HF logits to match
    hf_output = hf_model(example_wav, attention_mask=attention_mask)[0].transpose(0, 1)

    assert (
        hf_output.shape == fsq_output.shape
    ), f"Shapes don't match. Got {hf_output.shape} for HF and {fsq_output.shape} for fsq"
    assert torch.allclose(hf_output, fsq_output, atol=1e-2)
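

# Compares the CTC loss computed by fairseq's CtcCriterion with the loss returned by the
# Hugging Face model. Note: this check is not invoked by test_all below; it only makes sense
# for fine-tuned checkpoints, since it relies on processor.tokenizer.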
def test_loss(hf_model, fsq_model, example_wav, attention_mask, target):
    from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig
    from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask

    audio_cfg = AudioPretrainingConfig(labels="ltr", data="./data")
    task = AudioPretrainingTask.setup_task(audio_cfg)
    ctc = CtcCriterion(CtcCriterionConfig(), task)
    fsq_model.train()

    labels_dict = processor.tokenizer(target, padding="longest", return_tensors="pt")
    labels = labels_dict.input_ids
    target_lengths = labels_dict.attention_mask.sum(-1)

    sample = {
        "net_input": {
            "source": example_wav,
            "padding_mask": attention_mask.ne(1),
        },
        "target": labels,
        "target_lengths": target_lengths,
        "id": torch.zeros((1,)),
    }

    loss, _, _ = ctc(fsq_model, sample)

    # padded label positions are set to -100 so the Hugging Face CTC loss ignores them
    labels = labels_dict.attention_mask * labels + (1 - labels_dict.attention_mask) * -100

    hf_model.config.ctc_loss_reduction = "mean"
    hf_loss = hf_model(example_wav, attention_mask=attention_mask, labels=labels).loss

    print("Loss", loss)
    print("Hf loss", hf_loss)


def test_all(example_wav, attention_mask):
    with torch.no_grad():
        if finetuned:
            test_feature_extractor(
                hf_model.wav2vec2.feature_extractor, model.w2v_encoder.w2v_model.feature_extractor, example_wav
            )
        else:
            test_feature_extractor(hf_model.feature_extractor, model.feature_extractor, example_wav)
    print("Succeeded feature extractor test")

    with torch.no_grad():
        if finetuned:
            test_full_encoder(hf_model.wav2vec2, model.w2v_encoder.w2v_model, example_wav, attention_mask)
        else:
            test_full_encoder(hf_model, model, example_wav, attention_mask)
    print("Succeeded full encoder test")

    if finetuned:
        with torch.no_grad():
            test_full_model(hf_model, model, example_wav, attention_mask)
        print("Succeeded full model test")
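

# Load a few samples from the LibriSpeech dummy dataset and run the comparison.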
dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")


def map_to_array(batch):
    speech_array, _ = sf.read(batch["file"])
    batch["speech"] = speech_array
    return batch


def map_to_array_mp3(batch, i):
    # unused helper for locally converted Common Voice (nl) samples; kept for reference
    speech_array, sr = sf.read(f"/home/patrick/hugging_face/add_wav2vec/common_voice/cv-corpus-6.1-2020-12-11/nl/converted/sample_{i}.wav")
    batch["speech"] = speech_array
    batch["sampling_rate"] = sr
    return batch


dummy_speech_data = dummy_speech_data.map(map_to_array, remove_columns=["file"])
inputs = processor(dummy_speech_data[:3]["speech"], return_tensors="pt", padding="longest", return_attention_mask=True)

# transcriptions are only needed as targets for test_loss
transcription = dummy_speech_data[:3]["text"]

input_values = inputs.input_values
attention_mask = inputs.attention_mask

test_all(input_values, attention_mask)
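
# Optionally, for fine-tuned checkpoints, the CTC losses can be compared as well, e.g.:
#     test_loss(hf_model, model, input_values, attention_mask, transcription)
# (left disabled here because the fairseq AudioPretrainingTask expects a local "./data"
# directory with the "ltr" label dictionary)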