import datasets |
import fairseq |
import torch |
import os |
import soundfile as sf |
from datasets import load_dataset |
import sys |
from shutil import copyfile |
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor, Wav2Vec2Model, Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor |
# Command-line arguments:
#   1. hf_path               - identifier/path passed to the Hugging Face `from_pretrained` calls
#   2. fairseq_wav2vec2_path - path of the fairseq checkpoint to load
#   3. finetuned             - integer flag (0 or 1); non-zero selects the CTC (fine-tuned) setup
if len(sys.argv) < 4:
    # Fail with a usage hint instead of a bare IndexError when arguments are missing.
    raise SystemExit(f"usage: {sys.argv[0]} <hf_path> <fairseq_checkpoint_path> <finetuned: 0|1>")

hf_path = str(sys.argv[1])
fairseq_wav2vec2_path = str(sys.argv[2])
finetuned = bool(int(sys.argv[3]))

if finetuned:
    # Fine-tuned setup: full processor and CTC head on the Hugging Face side.
    processor = Wav2Vec2Processor.from_pretrained(hf_path)
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [fairseq_wav2vec2_path], arg_overrides={"data": "../add_wav2vec/data/temp"}
    )
    hf_model = Wav2Vec2ForCTC.from_pretrained(hf_path)
else:
    # Pretrained-only setup: feature extractor and the bare model.
    processor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path)
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
    hf_model = Wav2Vec2Model.from_pretrained(hf_path)

# The loader returns a list of models; only the first one is compared.
model = model[0]
model.eval()
def test_feature_extractor(hf_feat_extractor, fsq_feat_extract, example_wav): |
fsq_output = fsq_feat_extract(example_wav) |
hf_output = hf_feat_extractor(example_wav) |
assert ( |
hf_output.shape == fsq_output.shape |
), f"Shapes don't match. Got {hf_output.shape} for HF and {fsq_output.shape} for fsq" |
assert torch.allclose(hf_output, fsq_output, atol=1e-3) |
def test_full_encoder(hf_model, fsq_model, example_wav, attention_mask): |
fsq_output = fsq_model(example_wav, padding_mask=attention_mask.ne(1), mask=False, features_only=True)["x"] |
hf_output = hf_model(example_wav, attention_mask=attention_mask)[0] |
assert ( |
hf_output.shape == fsq_output.shape |
), f"Shapes don't match. Got {hf_output.shape} for HF and {fsq_output.shape} for fsq" |
assert torch.allclose(hf_output, fsq_output, atol=1e-2) |
def test_full_model(hf_model, fsq_model, example_wav, attention_mask): |
fsq_output = fsq_model(source=example_wav, padding_mask=attention_mask.ne(1))["encoder_out"] |
hf_output = hf_model(example_wav, attention_mask=attention_mask)[0].transpose(0, 1) |
assert ( |
hf_output.shape == fsq_output.shape |
), f"Shapes don't match. Got {hf_output.shape} for HF and {fsq_output.shape} for fsq" |
assert torch.allclose(hf_output, fsq_output, atol=1e-2) |
def test_loss(hf_model, fsq_model, example_wav, attention_mask, target):
    """Print the CTC loss of the fairseq and Hugging Face models for one batch.

    The two losses are only printed, not compared. The training mode of
    ``fsq_model`` and ``hf_model.config.ctc_loss_reduction`` are changed for
    the computation and restored afterwards, so calling this function does not
    alter later comparisons made with the same model objects.

    Args:
        hf_model: Hugging Face model called with ``labels``; its ``.loss`` is printed.
        fsq_model: fairseq model handed to the fairseq CTC criterion.
        example_wav: Input batch fed to both models.
        attention_mask: Mask passed as-is to ``hf_model``; the fairseq side
            receives ``attention_mask.ne(1)`` as its padding mask.
        target: Text targets, tokenized with the module-level ``processor.tokenizer``.
    """
    from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig
    from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask

    audio_cfg = AudioPretrainingConfig(labels="ltr", data="./data")
    task = AudioPretrainingTask.setup_task(audio_cfg)
    ctc = CtcCriterion(CtcCriterionConfig(), task)

    # Tokenize the targets with the module-level ``processor``.
    labels_dict = processor.tokenizer(target, padding="longest", return_tensors="pt")
    labels = labels_dict.input_ids
    target_lengths = labels_dict.attention_mask.sum(-1)

    sample = {
        "net_input": {
            "source": example_wav,
            "padding_mask": attention_mask.ne(1),
        },
        "target": labels,
        "target_lengths": target_lengths,
        "id": torch.zeros((1,)),
    }

    # Replace label positions where the tokenizer mask is 0 with -100 for the HF call.
    hf_labels = labels_dict.attention_mask * labels + (1 - labels_dict.attention_mask) * -100

    # Remember the state that is about to be mutated so it can be restored.
    was_training = fsq_model.training
    previous_reduction = hf_model.config.ctc_loss_reduction

    # NOTE(review): the criterion is run with the model in train mode, as in the
    # original code - presumably to skip eval-only metrics; confirm.
    fsq_model.train()
    try:
        loss, _, _ = ctc(fsq_model, sample)

        hf_model.config.ctc_loss_reduction = "mean"
        hf_loss = hf_model(example_wav, attention_mask=attention_mask, labels=hf_labels).loss
    finally:
        fsq_model.train(was_training)
        hf_model.config.ctc_loss_reduction = previous_reduction

    print("Loss", loss)
    print("Hf loss", hf_loss)
def test_all(example_wav, attention_mask):
    """Run every applicable comparison between the HF and fairseq models.

    Uses the module-level ``finetuned``, ``hf_model`` and ``model``. The
    feature-extractor and encoder comparisons always run; the full-model
    comparison runs only when ``finetuned`` is true. A message is printed
    after each comparison that completes.

    Args:
        example_wav: Input batch forwarded to each comparison.
        attention_mask: Mask forwarded to the encoder and full-model comparisons.
    """
    # For the fine-tuned setup the encoders are nested inside the CTC wrappers.
    if finetuned:
        hf_encoder = hf_model.wav2vec2
        fsq_encoder = model.w2v_encoder.w2v_model
    else:
        hf_encoder = hf_model
        fsq_encoder = model

    with torch.no_grad():
        test_feature_extractor(
            hf_encoder.feature_extractor, fsq_encoder.feature_extractor, example_wav
        )
    print("Succeded feature extractor Test")

    with torch.no_grad():
        test_full_encoder(hf_encoder, fsq_encoder, example_wav, attention_mask)
    print("Succeded full encoder test")

    if not finetuned:
        return

    with torch.no_grad():
        test_full_model(hf_model, model, example_wav, attention_mask)
    print("Succeded full model test")
# Validation split of the "clean" configuration of the dummy LibriSpeech dataset,
# used as input for the comparisons below.
dummy_speech_data = load_dataset(
    "patrickvonplaten/librispeech_asr_dummy",
    "clean",
    split="validation",
)
def map_to_array(batch):
    """Read the audio file named by ``batch["file"]`` into ``batch["speech"]``.

    Args:
        batch: Mapping with a ``"file"`` entry holding the audio path.

    Returns:
        The same ``batch`` object, with the decoded samples stored under ``"speech"``.
    """
    # sf.read returns a pair; the second element is not needed here.
    samples, _unused = sf.read(batch["file"])
    batch["speech"] = samples
    return batch
def map_to_array_mp3(
    batch,
    i,
    data_dir="/home/patrick/hugging_face/add_wav2vec/common_voice/cv-corpus-6.1-2020-12-11/nl/converted",
):
    """Load the ``i``-th converted sample into ``batch``.

    Args:
        batch: Mapping that receives the ``"speech"`` and ``"sampling_rate"`` entries.
        i: Index used to build the file name ``sample_{i}.wav``.
        data_dir: Directory containing the ``sample_{i}.wav`` files. The default
            is the location that was previously hard-coded, so existing calls
            behave as before; pass another directory to run on a different machine.

    Returns:
        The same ``batch`` object with ``"speech"`` and ``"sampling_rate"`` set
        from the two values returned by ``sf.read``.
    """
    speech_array, sr = sf.read(f"{data_dir}/sample_{i}.wav")
    batch["speech"] = speech_array
    batch["sampling_rate"] = sr
    return batch
# Decode every audio file into a "speech" entry and drop the "file" column.
dummy_speech_data = dummy_speech_data.map(map_to_array, remove_columns=["file"])

# Build one batch from the first three examples, padded to the longest one.
inputs = processor(
    dummy_speech_data[:3]["speech"],
    return_tensors="pt",
    padding="longest",
    return_attention_mask=True,
)
transciption = dummy_speech_data[:3]["text"]
input_values, attention_mask = inputs.input_values, inputs.attention_mask

# Run the comparisons on that batch.
test_all(input_values, attention_mask)