Wav2vec2 finetuning - Evaluation WER does not change

#1
by thardindubph200 - opened

I'm finetuning Wav2vec2 on a downstream task. Before finetuning, I checked the accuracy of pre-trained one comes in transformers library and it gives okay results. But when I finetune, the accuracy becomes Zero. I must be doing something wrong. Can someone help me figure out?

Screenshot from 2022-07-08 07-50-12.png
image.png

##******* code ********
from datasets import load_dataset, load_metric
from create_dataset import CreateAndLoadDataset
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import TrainingArguments, Trainer
from datasets import ClassLabel
import random
import pandas as pd
import json
import time
import sounddevice
import numpy as np
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
"""
Data collator that will dynamically pad the inputs received.
Args:
processor (:class:~transformers.Wav2Vec2Processor)
The processor used for proccessing the data.
padding (:obj:bool, :obj:str or :class:~transformers.tokenization_utils_base.PaddingStrategy, optional, defaults to :obj:True):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
* :obj:True or :obj:'longest': Pad to the longest sequence in the batch (or no padding if only a single
sequence if provided).
* :obj:'max_length': Pad to a maximum length specified with the argument :obj:max_length or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:False or :obj:'do_not_pad' (default): No padding (i.e., can output a batch with sequences of
different lengths).
"""

processor: Wav2Vec2Processor
padding: Union[bool, str] = True

def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
    # split inputs and labels since they have to be of different lenghts and need
    # different padding methods
    input_features = [{"input_values": feature["input_values"]} for feature in features]
    label_features = [{"input_ids": feature["labels"]} for feature in features]

    batch = self.processor.pad(
        input_features,
        padding=self.padding,
        return_tensors="pt",
    )
    with self.processor.as_target_processor():
        labels_batch = self.processor.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

    # replace padding with -100 to ignore loss correctly
    labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

    batch["labels"] = labels

    return batch

def show_random_elements(dataset, num_examples=10):
assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
picks = []
for _ in range(num_examples):
pick = random.randint(0, len(dataset)-1)
while pick in picks:
pick = random.randint(0, len(dataset)-1)
picks.append(pick)

df = pd.DataFrame(dataset[picks])
print(df)

def extract_all_chars(batch):
all_text = " ".join(batch["transcript"])
vocab = list(set(all_text))
return {"vocab": [vocab], "all_text": [all_text]}

def prepare_dataset(batch):
audio = batch["audio_filepath"]

# batched output is "un-batched" to ensure mapping is correct
batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
batch["input_length"] = len(batch["input_values"])

with processor.as_target_processor():
    batch["labels"] = processor(batch["transcript"]).input_ids
return batch

def compute_metrics(pred):
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)

pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

pred_str = processor.batch_decode(pred_ids)
# we do not want to group tokens when computing the metrics
label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

wer = wer_metric.compute(predictions=pred_str, references=label_str)

return {"wer": wer}

tk_audio_dataset = CreateAndLoadDataset()
print(tk_audio_dataset)
print(len(tk_audio_dataset["train"]))

show_random_elements(tk_audio_dataset["train"], num_examples=10)
vocabs = tk_audio_dataset.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=tk_audio_dataset.column_names["train"])

vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
print(vocab_dict)

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
print(len(vocab_dict))

with open('vocab.json', 'w') as vocab_file:
json.dump(vocab_dict, vocab_file)

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

tokenizer = Wav2Vec2Processor.from_pretrained("/home/tharindu/Desktop/black/codes/BPH/wav2vec2/model_git/wav2vec2-base-960h")

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

print(tk_audio_dataset['train'][0]["audio_filepath"])

Play a random data file to check

rand_int = random.randint(0, len(tk_audio_dataset["train"]))
print(tk_audio_dataset["train"][rand_int]["transcript"])
temp_sound_example = np.asarray(tk_audio_dataset["train"][rand_int]["audio_filepath"]["array"])
print(temp_sound_example)
sounddevice.play(temp_sound_example, 16000) # releases GIL
time.sleep(5)
######

Check details of a random file

rand_int = random.randint(0, len(tk_audio_dataset["train"]))
print("Target text:", tk_audio_dataset["train"][rand_int]["transcript"])
print("Input array shape:", np.asarray(tk_audio_dataset["train"][rand_int]["audio_filepath"]["array"]).shape)
print("Sampling rate:", tk_audio_dataset["train"][rand_int]["audio_filepath"]["sampling_rate"])
######

#the following line and the function is calls may be wrong. Need to rethink of having 3 columns in the dataset. accoording to the Colab instructions, resampling use 'audio' column and not the 'audio_filepath' column.
tk_audio_dataset = tk_audio_dataset.map(prepare_dataset, remove_columns=tk_audio_dataset.column_names["train"], num_proc=4)

max_input_length_in_sec = 4.0
tk_audio_dataset["train"] = tk_audio_dataset["train"].filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
wer_metric = load_metric("wer")

model = Wav2Vec2ForCTC.from_pretrained(
"/home/tharindu/Desktop/black/codes/BPH/wav2vec2/model_git_copy/wav2vec2-base-960h",
ctc_loss_reduction="mean",
pad_token_id=processor.tokenizer.pad_token_id,
)

print(model)

model.freeze_feature_encoder()

training_args = TrainingArguments(
output_dir="/home/tharindu/Desktop/black/codes/BPH/wav2vec2/model_git_copy/wav2vec2-base-960h",
group_by_length=True,
per_device_train_batch_size=8,
evaluation_strategy="steps",
num_train_epochs=50,
fp16=True,
gradient_checkpointing=True,
save_steps=500,
eval_steps=100, #originally it was 500
logging_steps=100, #originally it was 500
learning_rate=1e-4,
weight_decay=0.005,
warmup_steps=100, #original it was 1000
save_total_limit=2,
push_to_hub=False, #uncommnet later if necessary
)

trainer = Trainer(
model=model,
data_collator=data_collator,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=tk_audio_dataset["train"],
eval_dataset=tk_audio_dataset["test"],
tokenizer=processor.feature_extractor,
)

trainer.train()
trainer.save_model()

thardindubph200 changed discussion title from Wav2vec2 finetuning - Evaluation WER does not go to Wav2vec2 finetuning - Evaluation WER does not change

Screenshots of the codes for creating the dataset:

Screenshot from 2022-07-08 08-02-32.png

Screenshot from 2022-07-08 08-03-39.png

Screenshots of the finetuning code - the code is also copy pasted above but it doesn't appear well. so attaching screenshots here.

Screenshot from 2022-07-08 08-04-41.png
Screenshot from 2022-07-08 08-05-13.png
Screenshot from 2022-07-08 08-05-32.png
Screenshot from 2022-07-08 08-08-03.png

Sign up or log in to comment