Size mismatch error during training

#11
by taqwa92 - opened

Hey @taqwa92

The issue is with your target label sequences. Some of the label sequences have a length that exceeds the model’s maximum generation length. These must be very long sequences, as the maximum generation length is 448. This is the longest sequence the model is configured to handle (model.config.max_length).
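As a quick sanity check, you can look for label sequences that hit the limit. This is a standalone sketch with made-up token counts, not your actual data; in practice you'd compute the lengths from your tokenized labels:

```python
# hypothetical tokenized label lengths, for illustration only
label_lengths = [87, 212, 448, 512, 1030]

max_length = 448  # model.config.max_length for this checkpoint

# any sequence at or beyond the limit will cause the size mismatch
too_long = [n for n in label_lengths if n >= max_length]
print(too_long)  # → [448, 512, 1030]
```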

We've got two options here:

  1. Filter out any label sequences longer than the max length
  2. Increase the model's max length

First, we can compute the label length of each target sequence:

def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute input length in samples
    batch["input_length"] = len(audio["array"])

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["sentence"]).input_ids

    # compute labels length
    batch["labels_length"] = len(batch["labels"])
    return batch

And then filter out those that exceed the model's maximum length:

MAX_DURATION_IN_SECONDS = 30.0
max_input_length = MAX_DURATION_IN_SECONDS * 16000

def filter_inputs(input_length):
    """Filter inputs with zero input length or longer than 30s"""
    return 0 < input_length < max_input_length

max_label_length = model.config.max_length

def filter_labels(labels_length):
    """Filter label sequences longer than max length (448)"""
    return labels_length < max_label_length
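As a quick illustration of how the two predicates behave (a standalone sketch with example values; the constants mirror the ones above, with 448 standing in for model.config.max_length):

```python
MAX_DURATION_IN_SECONDS = 30.0
max_input_length = MAX_DURATION_IN_SECONDS * 16000  # 480,000 samples at 16 kHz

def filter_inputs(input_length):
    """Keep inputs that are non-empty and shorter than 30s."""
    return 0 < input_length < max_input_length

max_label_length = 448  # stands in for model.config.max_length

def filter_labels(labels_length):
    """Keep label sequences shorter than the max length."""
    return labels_length < max_label_length

print(filter_inputs(0))        # → False (empty audio is dropped)
print(filter_inputs(160000))   # → True  (a 10s clip is kept)
print(filter_labels(600))      # → False (too long for the model)
print(filter_labels(100))      # → True
```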

You can then apply the prepare_dataset function and the two new filter functions to your dataset common_voice as follows:

# pre-process
common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"])
# filter by audio length
common_voice = common_voice.filter(filter_inputs, input_columns=["input_length"], remove_columns=["input_length"])
# filter by label length
common_voice = common_voice.filter(filter_labels, input_columns=["labels_length"], remove_columns=["labels_length"])

That should pre-process the dataset and remove any label sequences that are too long for the model.
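To double-check the result, every remaining label length should sit strictly below the limit. A standalone sketch, using a plain list in place of the dataset's labels_length column:

```python
# toy stand-in for the labels_length column after prepare_dataset
labels_lengths = [30, 448, 500, 12]
max_label_length = 448

# mirror the filter_labels predicate from above
kept = [n for n in labels_lengths if n < max_label_length]

print(kept)                # → [30, 12]
print(max(kept))           # → 30, safely below the limit
```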

Alternatively, we can change the model’s max length to any value we want:

model.config.max_length = 500

This will update the max length to 500 tokens. Make sure to set it before you define the filter, so the new value takes effect:

max_label_length = model.config.max_length = 500

def filter_labels(labels_length):
    """Filter label sequences longer than the new max length (500)"""
    return labels_length < max_label_length
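The chained assignment above sets both names to 500 in one statement. A minimal standalone check, using a SimpleNamespace as a stand-in for model.config:

```python
from types import SimpleNamespace

# stand-in for model.config
config = SimpleNamespace(max_length=448)

# chained assignment: both the local name and the config attribute are updated
max_label_length = config.max_length = 500

print(max_label_length, config.max_length)  # → 500 500
```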

Hope that helps!

A lot of thanks to you, prof @sanchit-gandhi, it really helped me
