#!/usr/bin/env python3
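"""Sanity-check a converted Hugging Face Wav2Vec2 checkpoint against the
original fairseq checkpoint by comparing the feature-extractor outputs, the
full encoder outputs, and (for fine-tuned models) the full CTC model outputs.

Usage (paths are illustrative):
    python <this_script>.py <hf_checkpoint_dir> <fairseq_checkpoint.pt> <finetuned: 0|1>
"""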
import sys

import datasets
import fairseq
import soundfile as sf
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForCTC, Wav2Vec2Model, Wav2Vec2Processor

hf_path = str(sys.argv[1])  # converted Hugging Face checkpoint directory
fairseq_wav2vec2_path = str(sys.argv[2])  # original fairseq checkpoint file
finetuned = bool(int(sys.argv[3]))  # 1 for a fine-tuned (CTC) checkpoint, 0 for a pretrained one


if finetuned:
    # Fine-tuned (CTC) checkpoints come with a full processor (feature
    # extractor + tokenizer); fairseq additionally needs a `data` directory
    # to resolve the task dictionary.
    processor = Wav2Vec2Processor.from_pretrained(hf_path)
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [fairseq_wav2vec2_path], arg_overrides={"data": "../add_wav2vec/data/temp"}
    )
    hf_model = Wav2Vec2ForCTC.from_pretrained(hf_path)
else:
    # Pretrained checkpoints only need the feature extractor and the base model.
    processor = Wav2Vec2FeatureExtractor.from_pretrained(hf_path)
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_wav2vec2_path])
    hf_model = Wav2Vec2Model.from_pretrained(hf_path)

# `load_model_ensemble_and_task` returns a list of models; take the single checkpoint.
model = model[0]
model.eval()


def test_feature_extractor(hf_feat_extractor, fsq_feat_extract, example_wav):
    # Compare the convolutional feature extractors on the raw waveform.
    fsq_output = fsq_feat_extract(example_wav)
    hf_output = hf_feat_extractor(example_wav)

    assert (
        hf_output.shape == fsq_output.shape
    ), f"Shapes don't match. Got {hf_output.shape} for HF and {fsq_output.shape} for fsq"
    assert torch.allclose(hf_output, fsq_output, atol=1e-3)


def test_full_encoder(hf_model, fsq_model, example_wav, attention_mask):
    # fairseq expects a padding mask (True where padded), i.e. the inverse of
    # the HF attention mask; `features_only=True` returns the transformer
    # encoder output without masking or quantization.
    fsq_output = fsq_model(example_wav, padding_mask=attention_mask.ne(1), mask=False, features_only=True)["x"]
    hf_output = hf_model(example_wav, attention_mask=attention_mask)[0]

    assert (
        hf_output.shape == fsq_output.shape
    ), f"Shapes don't match. Got {hf_output.shape} for HF and {fsq_output.shape} for fsq"
    assert torch.allclose(hf_output, fsq_output, atol=1e-2)


def test_full_model(hf_model, fsq_model, example_wav, attention_mask):
    # fairseq's CTC encoder output is (time, batch, vocab); transpose the HF
    # logits from (batch, time, vocab) to match before comparing.
    fsq_output = fsq_model(source=example_wav, padding_mask=attention_mask.ne(1))["encoder_out"]
    hf_output = hf_model(example_wav, attention_mask=attention_mask)[0].transpose(0, 1)

    assert (
        hf_output.shape == fsq_output.shape
    ), f"Shapes don't match. Got {hf_output.shape} for HF and {fsq_output.shape} for fsq"
    assert torch.allclose(hf_output, fsq_output, atol=1e-2)


def test_loss(hf_model, fsq_model, example_wav, attention_mask, target):
    from fairseq.criterions.ctc import CtcCriterion, CtcCriterionConfig
    from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask

    # Build a fairseq CTC criterion; the task needs a data directory containing
    # the letter dictionary to construct the target dictionary.
    audio_cfg = AudioPretrainingConfig(labels="ltr", data="./data")
    task = AudioPretrainingTask.setup_task(audio_cfg)
    ctc = CtcCriterion(CtcCriterionConfig(), task)
    fsq_model.train()

    labels_dict = processor.tokenizer(target, padding="longest", return_tensors="pt")
    labels = labels_dict.input_ids
    target_lengths = labels_dict.attention_mask.sum(-1)

    sample = {
        "net_input": {
            "source": example_wav,
            "padding_mask": attention_mask.ne(1),
        },
        "target": labels,
        "target_lengths": target_lengths,
        "id": torch.zeros((1,)),
    }

    loss, _, _ = ctc(fsq_model, sample)

    # Replace padded label positions with -100 so the HF CTC loss ignores them.
    labels = labels_dict.attention_mask * labels + (1 - labels_dict.attention_mask) * -100

    hf_model.config.ctc_loss_reduction = "mean"
    hf_loss = hf_model(example_wav, attention_mask=attention_mask, labels=labels).loss

    print("Loss", loss)
    print("Hf loss", hf_loss)


def test_all(example_wav, attention_mask):
    with torch.no_grad():
        if finetuned:
            test_feature_extractor(
                hf_model.wav2vec2.feature_extractor, model.w2v_encoder.w2v_model.feature_extractor, example_wav
            )
        else:
            test_feature_extractor(
                hf_model.feature_extractor, model.feature_extractor, example_wav
            )
    print("Succeded feature extractor Test")

    with torch.no_grad():
        # IMPORTANT: It is assumed that layer_norm_first is FALSE
        # This is the case for `wav2vec_small_960h.pt`, but might not be for all models
        # Adapt if necessary
        if finetuned:
            test_full_encoder(hf_model.wav2vec2, model.w2v_encoder.w2v_model, example_wav, attention_mask)
        else:
            test_full_encoder(hf_model, model, example_wav, attention_mask)
    print("Succeded full encoder test")

    if finetuned:
        with torch.no_grad():
            # IMPORTANT: It is assumed that layer_norm_first is FALSE
            # This is the case for `wav2vec_small_960h.pt`, but might not be for all models
            # Adapt if necessary
            test_full_model(hf_model, model, example_wav, attention_mask)
        print("Succeded full model test")


# Load a tiny LibriSpeech dummy split to use as test input.
dummy_speech_data = datasets.load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")


def map_to_array(batch):
    # Decode each audio file into a raw float waveform.
    speech_array, _ = sf.read(batch["file"])
    batch["speech"] = speech_array
    return batch


def map_to_array_mp3(batch, i):
    # Unused alternative loader for locally converted Common Voice samples.
    speech_array, sr = sf.read(f"/home/patrick/hugging_face/add_wav2vec/common_voice/cv-corpus-6.1-2020-12-11/nl/converted/sample_{i}.wav")
    batch["speech"] = speech_array
    batch["sampling_rate"] = sr
    return batch


dummy_speech_data = dummy_speech_data.map(map_to_array, remove_columns=["file"])
inputs = processor(dummy_speech_data[:3]["speech"], return_tensors="pt", padding="longest", return_attention_mask=True)

transcription = dummy_speech_data[:3]["text"]

input_values = inputs.input_values
attention_mask = inputs.attention_mask

test_all(input_values, attention_mask)
# test_loss requires a fine-tuned checkpoint (it uses processor.tokenizer):
# test_loss(hf_model, model, input_values, attention_mask, transcription)