# training / flax /run_eval.py
# supawichwac's picture
# Saving train state of step 50
# 55f3766 verified
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluating a Whisper model on one or more evaluation datasets.
"""
# You can also adapt this script for your own speech recognition validation. Pointers for this are left as comments.
import logging
import os
import string
import sys
import time
from dataclasses import field
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import datasets
import evaluate
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import transformers
from datasets import Dataset, DatasetDict, IterableDatasetDict, load_dataset
from flax import jax_utils
from flax.jax_utils import pad_shard_unpad
from flax.training.common_utils import get_metrics, onehot
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
HfArgumentParser,
Seq2SeqTrainingArguments,
WhisperConfig,
WhisperFeatureExtractor,
WhisperProcessor,
WhisperTokenizerFast,
is_tensorboard_available,
is_wandb_available,
)
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from distil_whisper import FlaxWhisperForConditionalGeneration
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.27.0.dev0")

# The hint points at the requirements file of the transformers Flax speech-recognition example.
require_version(
    "datasets>=1.18.0",
    "To fix: pip install -r examples/flax/speech-recognition/requirements.txt",
)

logger = logging.getLogger(__name__)
@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer/feature extractor/processor we are going to evaluate.
    """

    model_name_or_path: str = field(
        metadata={"help": ("Path to pretrained model or model identifier from huggingface.co/models")}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    feature_extractor_name: Optional[str] = field(
        default=None,
        metadata={"help": "feature extractor name or path if not the same as model_name"},
    )
    processor_name: Optional[str] = field(
        default=None,
        metadata={"help": "processor name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": ("Where to store the pretrained models downloaded from huggingface.co")},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": ("Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.")},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": ("The specific model version to use (can be a branch name, tag name or commit id).")},
    )
    subfolder: str = field(
        default="",
        metadata={
            # NOTE: the two literals are concatenated, so the second one must start with a space.
            "help": "In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can"
            " specify the folder name here."
        },
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `transformers-cli login`"
                " (necessary to use this script with private models)."
            )
        },
    )
    # Looked up as an attribute of `jax.numpy` by the script (e.g. "float32" -> jnp.float32).
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized"
                " and evaluated. Choose one of `[float32, float16, bfloat16]`."
            )
        },
    )
    load_with_scan: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether to load the model with scan enabled. Required when the model was saved with scan enabled."
            )
        },
    )
    return_timestamps: bool = field(
        default=False, metadata={"help": "Whether or not to predict timestamps in the generation step."}
    )
@flax.struct.dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for evaluation.

    Multiple datasets are given as '+'-separated strings (one entry per dataset for names, configs,
    splits and text column names); each dataset is then evaluated separately.
    """

    dataset_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the dataset to use (via the datasets library). Evaluate on "
            "multiple datasets by separating dataset names by a '+' symbol."
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The configuration name of the dataset to use (via the datasets library). When evaluating "
            "on multiple datasets, pass one config per dataset separated by a '+' symbol."
        },
    )
    dataset_split_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The split name of the dataset to use (via the datasets library). When evaluating on "
            "multiple datasets, pass one split per dataset separated by a '+' symbol."
        },
    )
    dataset_cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to cache directory for saving and loading datasets"},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    audio_column_name: str = field(
        default="audio",
        metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
    )
    text_column_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the dataset column containing the text data. Defaults to `text`."},
    )
    max_duration_in_seconds: float = field(
        default=30.0,
        metadata={"help": "Filter audio files that are longer than `max_duration_in_seconds` seconds"},
    )
    min_duration_in_seconds: float = field(
        default=0.0,
        metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"},
    )
    max_label_length: int = field(
        default=128,
        metadata={"help": "Truncate transcriptions that are longer than `max_label_length` tokens."},
    )
    pad_target_to_multiple_of: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "If set will pad the target sequence to a multiple of the provided"
                " value. This is important to avoid triggering recompilations on TPU."
                " If unspecified, will default to padding the targets to max length."
            )
        },
    )
    preprocessing_only: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to only do data preprocessing and skip evaluation. This is"
                " especially useful when data preprocessing errors out in distributed"
                " mode due to timeout. In this case, one should run the"
                " preprocessing in a non-distributed setup with"
                " `preprocessing_only=True` so that the cached datasets can"
                " consequently be loaded in distributed evaluation"
            )
        },
    )
    wandb_project: str = field(
        default="distil-whisper",
        metadata={"help": "The name of the wandb project."},
    )
    wandb_name: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the wandb run."},
    )
    wandb_job_type: str = field(
        default="distil-whisper",
        metadata={"help": "The name of the wandb job type."},
    )
    wandb_dir: Optional[str] = field(
        default=None,
        metadata={"help": "The absolute path to save the wandb logs."},
    )
    save_code_to_wandb: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to save main script to wandb. This is valuable for improving"
                " experiment reproducibility and to diff code across experiments in"
                " the UI."
            )
        },
    )
    streaming: bool = field(
        default=True,
        metadata={"help": "Whether to use Datasets' streaming mode to load and pre-process the data."},
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of eval examples to this value if set."},
    )
    log_audio: Optional[bool] = field(
        default=False,
        metadata={"help": "For debugging purposes, record the audio samples as well as the ground truths / preds."},
    )
def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
    """
    Build decoder inputs by moving every label one position to the right along axis 1.

    The last column of `label_ids` falls off the end, and column 0 of the result is set to
    `decoder_start_token_id`. The input array is not modified; a new array of the same dtype is returned.
    """
    # np.roll returns a fresh array; the wrapped-around last column is then overwritten by the start token.
    rolled = np.roll(label_ids, 1, axis=1)
    rolled[:, 0] = decoder_start_token_id
    return rolled
@flax.struct.dataclass
class FlaxDataCollatorSpeechSeq2SeqWithPadding:
    """
    Data collator that dynamically pads the audio inputs and the tokenized targets of a batch.

    Args:
        processor (e.g. a [`WhisperProcessor`])
            The processor used for processing the data. Its `feature_extractor` pads the inputs and its
            `tokenizer` pads the labels; `model_input_names[0]` names the input feature key.
        decoder_start_token_id (:obj:`int`)
            The token prepended to the shifted labels to build `decoder_input_ids`. A leading copy of it in
            the tokenized labels is stripped first.
        input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`"max_length"`):
            Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
              different lengths).
        target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`"max_length"`):
            Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
            See above for details.
        max_target_length (:obj:`int`, `optional`):
            Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
        log_audio (:obj:`bool`):
            Whether we're logging audio samples as part of our eval. If so, will forward on the audio samples to the batch.
        audio_column_name (:obj:`str`):
            Name of the audio column in the dataset. Only relevant if logging audio samples.
    """

    processor: Any
    decoder_start_token_id: int
    input_padding: Union[bool, str] = "max_length"
    target_padding: Union[bool, str] = "max_length"
    max_target_length: Optional[int] = None
    log_audio: Optional[bool] = False
    audio_column_name: Optional[str] = "audio"

    def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
        """
        Collate a list of preprocessed examples into one padded batch of numpy arrays.

        Returns the batch produced by `feature_extractor.pad` (dict-like) holding the padded input features, plus
        `labels` (padding replaced by -100), `decoder_input_ids` and, if `log_audio` is set, the raw `audio` samples.
        """
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        model_input_name = self.processor.model_input_names[0]

        # dataloader returns a list of features which we convert to a dict
        input_features = {model_input_name: [feature[model_input_name] for feature in features]}
        label_features = {"input_ids": [feature["labels"] for feature in features]}

        # pad inputs and labels, returning numpy arrays
        batch = self.processor.feature_extractor.pad(
            input_features,
            padding=self.input_padding,
            return_tensors="np",
        )
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            max_length=self.max_target_length,
            padding=self.target_padding,
            return_tensors="np",
        )

        labels = labels_batch["input_ids"]
        attention_mask = labels_batch["attention_mask"]

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().item():
            labels = labels[:, 1:]
            attention_mask = attention_mask[:, 1:]

        decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)

        # replace padding with -100 to ignore correctly when computing the loss
        labels = np.where(attention_mask == 1, labels, -100)

        batch["labels"] = labels
        batch["decoder_input_ids"] = decoder_input_ids

        if self.log_audio:
            batch["audio"] = [feature[self.audio_column_name] for feature in features]

        return batch
def get_data_loader(
    dataset: Dataset,
    batch_size: int,
    data_collator: FlaxDataCollatorSpeechSeq2SeqWithPadding,
    dataloader_num_workers: int = 0,
    pin_memory: bool = True,
) -> DataLoader:
    """
    Wrap `dataset` in a torch `DataLoader` that yields collated batches in dataset order.

    Every batch holds `batch_size` samples except possibly the last one, which is kept (``drop_last=False``)
    and may contain anywhere from 1 to `batch_size` samples.

    Args:
        dataset (Dataset): dataset from which to load the data.
        batch_size (int): how many samples per batch to load.
        data_collator (FlaxDataCollatorSpeechSeq2SeqWithPadding): collate function that merges a list of
            samples into a single batch.
        dataloader_num_workers (int, optional): how many subprocesses to use for data loading; ``0`` loads
            in the main process. Defaults to ``0``.
        pin_memory (bool, optional): forwarded to `DataLoader`; if ``True`` the loader copies tensors into
            pinned memory before returning them. Defaults to ``True``.

    Returns:
        DataLoader: the configured loader.
    """
    loader_options = dict(
        batch_size=batch_size,
        drop_last=False,
        pin_memory=pin_memory,
        collate_fn=data_collator,
        num_workers=dataloader_num_workers,
    )
    return DataLoader(dataset, **loader_options)
def write_metric(summary_writer, eval_metrics, step, prefix="eval"):
    """Write each entry of `eval_metrics` as a scalar tagged `<prefix>/<metric name>` at `step`."""
    for name in eval_metrics:
        tag = f"{prefix}/{name}"
        summary_writer.scalar(tag, eval_metrics[name], step)
def write_wandb_metric(wandb_logger, metrics, train_time, prefix):
    """
    Log `metrics` to wandb in one call, each key prefixed with `<prefix>/`.

    The elapsed time `train_time` is added as `<prefix>/time`, overriding any `time` entry in `metrics`.
    """
    payload = {f"{prefix}/{key}": value for key, value in metrics.items()}
    payload[f"{prefix}/time"] = train_time
    # TODO(SG): bug with wandb means we can't log the step count
    wandb_logger.log(payload)
def convert_audio_to_wandb(wandb_logger, audio):
    """
    Wrap a decoded audio sample (a dict with "array" and "sampling_rate") as a `wandb_logger.Audio` object.

    A new axis is inserted at position 1 of the waveform before it is handed to wandb.
    """
    waveform = audio["array"]
    rate = audio["sampling_rate"]
    return wandb_logger.Audio(waveform[:, None], sample_rate=rate)
def write_wandb_pred(
    wandb_logger,
    eval_audios,
    pred_str,
    label_str,
    norm_pred_str,
    norm_label_str,
    prefix="eval",
    num_lines=200000,
):
    """
    Log predictions to wandb as two tables: all predictions, and those whose normalised form is incorrect.

    Args:
        wandb_logger: the imported `wandb` module (anything exposing `log`, `Table` and `Audio`).
        eval_audios: decoded audio samples, one per prediction, or an empty sequence to omit the audio column.
        pred_str: raw predicted transcriptions.
        label_str: raw reference transcriptions, aligned with `pred_str`.
        norm_pred_str: normalised predictions, aligned with `pred_str`.
        norm_label_str: normalised references, aligned with `pred_str`.
        prefix: namespace of the logged keys (`<prefix>/all_predictions`, `<prefix>/incorrect_predictions`).
        num_lines: maximum number of rows logged in each table.
    """
    columns = ["Target", "Pred", "Norm Target", "Norm Pred"]
    # convert str data to a wandb compatible format
    str_data = [
        [label_str[i], pred, norm_label_str[i], norm_pred_str[i]] for i, pred in enumerate(pred_str)
    ]
    if len(eval_audios) > 0:
        columns.insert(0, "Audio")
        str_data = [
            [convert_audio_to_wandb(wandb_logger, eval_audios[i]), *str_data[i]]
            for i in range(len(pred_str))
        ]

    # log as a table with the appropriate headers
    wandb_logger.log(
        {f"{prefix}/all_predictions": wandb_logger.Table(columns=columns, data=str_data[:num_lines])},
    )

    # log incorrect normalised predictions: rows whose "Norm Target" (second-to-last) and "Norm Pred" (last)
    # entries differ. Filtering the Python rows directly (instead of 2-D indexing a numpy array built from them)
    # also works when there are no predictions, where the empty array is 1-D and `[:, -2]` would raise.
    str_data_incorrect = [row for row in str_data if row[-2] != row[-1]]
    wandb_logger.log(
        {f"{prefix}/incorrect_predictions": wandb_logger.Table(columns=columns, data=str_data_incorrect[:num_lines])},
    )
def convert_dataset_str_to_list(
    dataset_names, dataset_config_names, splits=None, text_column_names=None, dataset_hours=None, default_split="train"
):
    """
    Convert '+'-separated dataset specifications into one dict per dataset.

    Example: names "librispeech_asr+gigaspeech", configs "all+l", splits "validation.clean+validation" ->
    [{"name": "distil-whisper/librispeech_asr", "config": "all", "split": "validation.clean", ...},
     {"name": "distil-whisper/gigaspeech", "config": "l", "split": "validation", ...}]

    Args:
        dataset_names: '+'-separated dataset names. Names without an org prefix ("/") are assumed to live
            under the `distil-whisper` org on the Hub. A non-string (already split) value is used as-is,
            together with the other arguments as given.
        dataset_config_names: '+'-separated config names, one per dataset, or None to use each dataset's
            default config.
        splits: optional '+'-separated split names; defaults to `default_split` for every dataset.
        text_column_names: optional '+'-separated text column names; defaults to "text" for every dataset.
        dataset_hours: optional '+'-separated number of hours per dataset, parsed as floats.
        default_split: split used for every dataset when `splits` is None.

    Returns:
        A list of dicts with keys "name", "config", "split", "text_column_name" and "hours".

    Raises:
        ValueError: if the number of configs, splits, text column names or hours differs from the number
            of datasets.
    """
    if isinstance(dataset_names, str):
        # we assume that all the datasets we're using derive from the distil-whisper org on the Hub - prepend the
        # org name if necessary
        dataset_names = [name if "/" in name else f"distil-whisper/{name}" for name in dataset_names.split("+")]
        if dataset_config_names is not None:
            dataset_config_names = dataset_config_names.split("+")
        splits = splits.split("+") if splits is not None else None
        text_column_names = text_column_names.split("+") if text_column_names is not None else None
        dataset_hours = dataset_hours.split("+") if dataset_hours is not None else None

    # No config given (the CLI default): let `load_dataset` pick each dataset's default config.
    if dataset_config_names is None:
        dataset_config_names = [None] * len(dataset_names)

    # basic checks to ensure we've got the right number of datasets/configs/splits/columns/hours
    if len(dataset_names) != len(dataset_config_names):
        raise ValueError(
            f"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(dataset_config_names)} configs."
        )

    if splits is not None and len(splits) != len(dataset_names):
        raise ValueError(
            f"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits."
        )

    if text_column_names is not None and len(text_column_names) != len(dataset_names):
        raise ValueError(
            f"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and"
            f" {len(text_column_names)} text column names."
        )

    if dataset_hours is not None:
        if len(dataset_hours) != len(dataset_names):
            raise ValueError(
                f"Ensure one number of hours is passed for each dataset, got {len(dataset_names)} datasets and "
                f"{len(dataset_hours)} hours."
            )
        dataset_hours = [float(ds_hours) for ds_hours in dataset_hours]
    else:
        dataset_hours = [None] * len(dataset_names)

    text_column_names = text_column_names if text_column_names is not None else ["text"] * len(dataset_names)
    splits = splits if splits is not None else [default_split] * len(dataset_names)

    # All lists have been checked to have one entry per dataset, so zip pairs them up exactly.
    return [
        {
            "name": name,
            "config": config,
            "split": split,
            "text_column_name": text_column_name,
            "hours": hours,
        }
        for name, config, split, text_column_name, hours in zip(
            dataset_names, dataset_config_names, splits, text_column_names, dataset_hours
        )
    ]
class FlaxWhisperFeatureExtractor(WhisperFeatureExtractor):
    """Whisper feature extractor whose log-mel spectrogram is computed with torch's STFT instead of numpy's."""

    def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
        """
        Return the log-mel spectrogram of `waveform` as a numpy array, computed with torch operations.

        The torch STFT is roughly 5x faster than the native numpy implementation in transformers and agrees
        with it to within 1e-5 absolute tolerance.
        """
        signal = torch.from_numpy(waveform).type(torch.float32)
        spectrum = torch.stft(
            signal,
            self.n_fft,
            self.hop_length,
            window=torch.hann_window(self.n_fft),
            return_complex=True,
        )
        # Power spectrum, with the last entry along the final axis discarded.
        power = spectrum[..., :-1].abs() ** 2

        filter_bank = torch.from_numpy(self.mel_filters).type(torch.float32)
        log_mel = torch.clamp(filter_bank.T @ power, min=1e-10).log10()
        # Floor values at 8 (log10 units) below the maximum, then shift and scale.
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        return ((log_mel + 4.0) / 4.0).numpy()
def main():
    """
    Evaluate a Flax Whisper checkpoint on one or more speech-recognition datasets.

    Parses the model/data/training arguments (from the CLI or a single JSON file), sets up logging and the
    optional TensorBoard / wandb reporting, loads and preprocesses the evaluation set(s), and then for every
    split runs a pmapped loss computation and, if `predict_with_generate` is set, generation. It reports the
    eval loss, the orthographic WER and the normalised WER.
    """
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your JAX/Flax versions.
    send_example_telemetry("run_flax_speech_recognition_seq2seq", model_args, data_args, framework="flax")

    # 2. Setup logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Set the verbosity to info of the Transformers logger.
    # We only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info("Evaluation parameters %s", training_args)

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if "tensorboard" in training_args.report_to:
        if has_tensorboard and jax.process_index() == 0:
            try:
                from flax.metrics.tensorboard import SummaryWriter

                summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
            except ImportError as ie:
                has_tensorboard = False
                logger.warning(
                    "Unable to display metrics through TensorBoard because some" f" package are not installed: {ie}"
                )
        else:
            logger.warning(
                "Unable to display metrics through TensorBoard because the package is"
                " not installed: Please run `pip install tensorboard` to enable."
            )

    # Enable wandb only on the master node
    has_wandb = is_wandb_available()
    if "wandb" in training_args.report_to:
        if has_wandb and jax.process_index() == 0:
            import wandb as wandb_logger

            # Set up wandb run
            wandb_logger.init(
                project=data_args.wandb_project,
                name=data_args.wandb_name,
                job_type=data_args.wandb_job_type,
                dir=data_args.wandb_dir,
                save_code=data_args.save_code_to_wandb,
            )
        else:
            logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")

    # 3. Load dataset
    raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()

    # Convert lists of dataset names/configs/splits to a dict
    # names: "librispeech_asr+gigaspeech", configs: "all+l", splits: "validation.clean+validation"
    # -> [{"name": "distil-whisper/librispeech_asr", "config": "all", "split": "validation.clean", ...},
    #     {"name": "distil-whisper/gigaspeech", "config": "l", "split": "validation", ...}]
    dataset_names_dict = convert_dataset_str_to_list(
        data_args.dataset_name,
        data_args.dataset_config_name,
        splits=data_args.dataset_split_name,
        text_column_names=data_args.text_column_name,
    )

    # A single eval set is stored under "eval"; multiple eval sets each get a readable "<name>/<split>" key.
    multiple_eval_sets = len(dataset_names_dict) > 1
    dataset_iterator = (
        tqdm(dataset_names_dict, desc="Loading datasets...") if multiple_eval_sets else dataset_names_dict
    )
    for dataset_dict in dataset_iterator:
        if multiple_eval_sets:
            # Clean-up the dataset name for pretty logging
            # ("distil-whisper/librispeech_asr", "validation.clean") -> "librispeech_asr/validation-clean"
            split_key = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
        else:
            split_key = "eval"

        raw_datasets[split_key] = load_dataset(
            dataset_dict["name"],
            dataset_dict["config"],
            split=dataset_dict["split"],
            cache_dir=data_args.dataset_cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
            streaming=data_args.streaming,
        )
        column_names = list(raw_datasets[split_key].features.keys())
        if dataset_dict["text_column_name"] not in column_names:
            raise ValueError(
                f"`--text_column_name` {dataset_dict['text_column_name']} not found in the evaluation "
                f"dataset {dataset_dict['name']}. Ensure `text_column_name` is set to the correct column "
                f"for the target text. Should be one of {' '.join(column_names)}"
            )
        if dataset_dict["text_column_name"] != "text":
            raw_datasets[split_key] = raw_datasets[split_key].rename_column(dataset_dict["text_column_name"], "text")

    # 4. Load pretrained model, tokenizer, and feature extractor
    config = WhisperConfig.from_pretrained(
        (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    feature_extractor = FlaxWhisperFeatureExtractor.from_pretrained(
        (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = WhisperTokenizerFast.from_pretrained(
        (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    processor = WhisperProcessor.from_pretrained(
        (model_args.processor_name if model_args.processor_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=getattr(jnp, model_args.dtype),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        _do_init=False,
        subfolder=model_args.subfolder,
        # use_scan=model_args.load_with_scan,  # Model might have (erroneously) been saved with scan still enabled
    )

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    # disable scan if necessary (makes the inference step faster)
    if model_args.load_with_scan:
        model.disable_scan()  # to disable scan in the nn.Module
        params = model.convert_scan_to_unroll(params)  # to convert the scan params to unrolled

    # 5. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
    # so we just need to set the correct target sampling rate.
    raw_datasets = raw_datasets.cast_column(
        data_args.audio_column_name,
        datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
    )

    # 6. Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    max_label_length = (
        data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
    )
    audio_column_name = data_args.audio_column_name
    num_workers = data_args.preprocessing_num_workers
    dataloader_num_workers = training_args.dataloader_num_workers
    model_input_name = feature_extractor.model_input_names[0]
    normalizer = EnglishTextNormalizer(tokenizer.english_spelling_normalizer)

    if data_args.max_eval_samples is not None:
        for split in raw_datasets:
            raw_datasets[split] = (
                raw_datasets[split].take(data_args.max_eval_samples)
                if data_args.streaming
                else raw_datasets[split].select(range(data_args.max_eval_samples))
            )

    def prepare_dataset(batch):
        # process audio
        sample = batch[audio_column_name]
        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
        # store the input features of this (single) example
        batch[model_input_name] = inputs.get(model_input_name)[0]

        # process targets
        input_str = batch["text"]
        batch["labels"] = tokenizer(input_str, max_length=max_label_length, truncation=True).input_ids
        return batch

    vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
    for split in raw_datasets:
        raw_datasets_features = list(raw_datasets[split].features.keys())
        if data_args.log_audio:
            # if logging audio samples preserve the audio column when mapping the dataset
            raw_datasets_features.remove(audio_column_name)
        map_fn = partial(
            raw_datasets[split].map,
            function=prepare_dataset,
            remove_columns=raw_datasets_features,
        )
        vectorized_datasets[split] = (
            map_fn(num_proc=num_workers, desc="preprocess eval dataset")
            if not data_args.streaming
            else map_fn()  # In streaming, we can't run multiproc - errors out if we try to
        )

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with `args.preprocessing_only` since there will mostly likely
    # be a timeout when running the script in distributed mode.
    # In a second step `args.preprocessing_only` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        # streaming (iterable) datasets have no on-disk cache files, hence the getattr fallback
        cache = {k: getattr(v, "cache_files", None) for k, v in vectorized_datasets.items()}
        logger.info(f"Data preprocessing finished. Files cached at {cache}.")
        return

    # 7. Load Metric
    metric = evaluate.load("wer")
    # convention is that we space all punctuation *except* apostrophes
    all_punctuation = list(string.punctuation.replace("'", ""))
    # translation table that surrounds every one of those punctuation marks with spaces in a single pass
    punctuation_spacing = str.maketrans({punctuation: f" {punctuation} " for punctuation in all_punctuation})
    return_timestamps = model_args.return_timestamps

    def compute_metrics(preds, labels):
        # replace padded labels by the padding token
        for idx in range(len(labels)):
            labels[idx][labels[idx] == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # space punctuation for orthographic WER (c.f. ESB paper https://arxiv.org/abs/2210.13352).
        # Every punctuation mark is spaced within each string, so predictions and references stay one per sample.
        spaced_pred_str = [pred.translate(punctuation_spacing) for pred in pred_str]
        spaced_label_str = [label.translate(punctuation_spacing) for label in label_str]
        wer_ortho = 100 * metric.compute(predictions=spaced_pred_str, references=spaced_label_str)

        # normalize everything and re-compute the WER
        norm_pred_str = [normalizer(pred) for pred in pred_str]
        norm_label_str = [normalizer(label) for label in label_str]

        # filtering step to only evaluate the samples that correspond to non-zero normalized references;
        # the raw strings are filtered identically so the logged pred/labels match the norm_pred/norm_labels
        keep = [i for i, norm_label in enumerate(norm_label_str) if len(norm_label) > 0]
        pred_str = [pred_str[i] for i in keep]
        label_str = [label_str[i] for i in keep]
        norm_pred_str = [norm_pred_str[i] for i in keep]
        norm_label_str = [norm_label_str[i] for i in keep]

        wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
        return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str

    data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
        input_padding="longest",
        target_padding="max_length",
        max_target_length=max_label_length,
        log_audio=data_args.log_audio,
    )

    # Store some constants
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    # label smoothed cross entropy
    def loss_fn(logits, labels, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss, i.e. where labels are not set to -100
        padding_mask = labels >= 0
        loss = loss * padding_mask
        loss = loss.sum()
        num_labels = padding_mask.sum()
        return loss, num_labels

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, freeze_encoder=True, train=False)[0]

        loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
        num_labels = jax.lax.psum(num_labels, "batch")

        # true loss = total loss / total samples
        loss = jax.lax.psum(loss, "batch")
        loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)

        metrics = {"loss": loss}
        return metrics

    # Define generation function
    num_beams = (
        training_args.generation_num_beams
        if training_args.generation_num_beams is not None
        else model.config.num_beams
    )

    # forcing the language and task tokens helps the flax teacher model in its generations
    gen_kwargs = {
        "max_length": max_label_length,
        "num_beams": num_beams,
        "language": "<|en|>",
        "task": "transcribe",
        "return_timestamps": return_timestamps,
    }

    def generate_step(params, batch):
        output_ids = model.generate(
            batch[model_input_name],
            attention_mask=batch.get("attention_mask"),
            params=params,
            freeze_encoder=True,
            **gen_kwargs,
        )
        return output_ids.sequences

    # Create parallel version of the eval and generate step
    p_eval_step = jax.pmap(
        partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor),
        "batch",
    )
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate params on each device
    params = jax_utils.replicate(params)

    # Step value used for TensorBoard scalars. `ModelArguments` declares no `step` field, so reading
    # `model_args.step` directly raised AttributeError; fall back to 0 when the attribute is absent.
    tensorboard_step = getattr(model_args, "step", 0)

    def run_evaluation(split="eval"):
        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_preds = []
        eval_labels = []
        eval_audios = []
        eval_start = time.time()

        eval_loader = get_data_loader(
            vectorized_datasets[split],
            batch_size=eval_batch_size,
            data_collator=data_collator,
            dataloader_num_workers=dataloader_num_workers,
        )
        for batch in tqdm(eval_loader, desc=f"Evaluating {split}..."):
            # Model forward
            labels = batch["labels"]
            if data_args.log_audio:
                eval_audios.extend(batch.pop("audio"))
            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                params, batch.data, min_device_batch=per_device_eval_batch_size
            )
            eval_metrics.append(metrics)

            # generation
            if training_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(
                    params, batch.data, min_device_batch=per_device_eval_batch_size
                )
                eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                eval_labels.extend(labels)

        eval_time = time.time() - eval_start

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # compute WER metric
        wer_desc = ""
        if training_args.predict_with_generate:
            wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(eval_preds, eval_labels)
            eval_metrics.update(wer_metric)
            wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])

        # Print metrics
        logger.info(f"Eval Loss: {eval_metrics['loss']} | {wer_desc}")

        # Save metrics
        if has_tensorboard and jax.process_index() == 0 and "tensorboard" in training_args.report_to:
            write_metric(summary_writer, eval_metrics, tensorboard_step, prefix=split)

        if has_wandb and jax.process_index() == 0 and "wandb" in training_args.report_to:
            write_wandb_metric(wandb_logger, eval_metrics, eval_time, prefix=split)
            if training_args.predict_with_generate:
                write_wandb_pred(
                    wandb_logger, eval_audios, pred_str, label_str, norm_pred_str, norm_label_str, prefix=split
                )

    logger.info("***** Running Eval *****")
    logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
    logger.info(f" Total eval batch size (w. parallel & distributed) = {eval_batch_size}")
    for split in vectorized_datasets:
        run_evaluation(split=split)
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script, not when it is imported.
    main()